mirror of https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-11 07:38:08 -05:00

Compare commits
8 Commits
al/vectori
...
go/chore/t

| Author | SHA1 | Date |
|---|---|---|
| | 6c8829c337 | |
| | bef89c8b7b | |
| | 3dfea472ef | |
| | 1caf8bc65a | |
| | 994bbdae3f | |
| | 014ca36434 | |
| | d9499be011 | |
| | 65b956f3fa | |
@@ -544,6 +544,11 @@ __device__ T *get_ith_mask_kth_block(T *ptr, int i, int k, int level,
                                     uint32_t polynomial_size,
                                     int glwe_dimension, uint32_t level_count);

template <typename T, uint32_t polynomial_size, uint32_t glwe_dimension,
          uint32_t level_count, uint32_t level_id>
__device__ const T *get_ith_mask_kth_block_2_2_params(const T *ptr,
                                                      int iteration, int k);

template <typename T>
__device__ T *get_ith_body_kth_block(T *ptr, int i, int k, int level,
                                     uint32_t polynomial_size,
@@ -137,6 +137,34 @@ public:
  }
};

// Performs the decomposition for 2_2 params; assumes level_count = 1.
// This specialized version is needed if we plan to keep everything in
// registers.
template <typename T, class params, uint32_t base_log>
__device__ void decompose_and_compress_level_2_2_params(double2 *result,
                                                        T *state) {
  constexpr T mask_mod_b = (1ll << base_log) - 1ll;
  for (int i = 0; i < params::opt / 2; i++) {
    auto input1 = state[i];
    auto input2 = state[i + params::opt / 2];
    T res_re = input1 & mask_mod_b;
    T res_im = input2 & mask_mod_b;

    input1 >>= base_log; // Update state
    input2 >>= base_log; // Update state

    T carry_re = ((res_re - 1ll) | input1) & res_re;
    T carry_im = ((res_im - 1ll) | input2) & res_im;
    carry_re >>= (base_log - 1);
    carry_im >>= (base_log - 1);

    res_re -= carry_re << base_log;
    res_im -= carry_im << base_log;

    typecast_torus_to_double(res_re, result[i].x);
    typecast_torus_to_double(res_im, result[i].y);
  }
}

template <typename Torus>
__device__ Torus decompose_one(Torus &state, Torus mask_mod_b, int base_log) {
  Torus res = state & mask_mod_b;
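The carry logic above re-centers each extracted digit into the signed range [-B/2, B/2). A minimal host-side sketch of the same arithmetic for a single 64-bit value (illustrative values, not the device code):

```cpp
// Balanced one-level decomposition of a 64-bit decomposer state with
// base_log bits, mirroring the carry logic of the kernel above.
#include <cstdint>
#include <cstdio>

int main() {
  const uint64_t base_log = 22;
  const uint64_t mask_mod_b = (1ull << base_log) - 1ull;

  uint64_t state = 0x3FFFFFull; // example decomposer state (illustrative)
  uint64_t res = state & mask_mod_b;
  state >>= base_log;

  // If the digit's top bit is set, re-center it into [-B/2, B/2) by
  // subtracting B; the `| state` term implements the tie-breaking rule of
  // the balanced decomposition.
  uint64_t carry = ((res - 1ull) | state) & res;
  carry >>= (base_log - 1);
  int64_t digit = (int64_t)(res - (carry << base_log));

  printf("digit = %lld\n", (long long)digit);
  return 0;
}
```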
@@ -91,6 +91,23 @@ __device__ inline T init_decomposer_state(T input, uint32_t base_log,
  return res - (need_balance << rep_bit_count);
}

template <typename T, uint32_t base_log, uint32_t level_count>
__device__ inline T init_decomposer_state_2_2_params(T input) {
  constexpr T rep_bit_count = level_count * base_log;
  constexpr T non_rep_bit_count = sizeof(T) * 8 - rep_bit_count;
  T res = input >> (non_rep_bit_count - 1);
  T rounding_bit = res & (T)(1);
  res++;
  res >>= 1;
  constexpr T torus_max = scalar_max<T>();
  constexpr T mod_mask = torus_max >> non_rep_bit_count;
  res &= mod_mask;
  T shifted_random = rounding_bit << (rep_bit_count - 1);
  T need_balance =
      (((res - (T)(1)) | shifted_random) & res) >> (rep_bit_count - 1);
  return res - (need_balance << rep_bit_count);
}

template <typename T>
__device__ __forceinline__ void modulus_switch(T input, T &output,
                                               uint32_t log_modulus) {
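The function rounds a torus element to its closest value representable on base_log * level_count most-significant bits. A host-side sketch with concrete numbers (uint64_t, base_log = 22, level_count = 1; illustrative, not the device code):

```cpp
// Round a 64-bit torus element to rep_bit_count = base_log * level_count
// most-significant bits, mirroring init_decomposer_state_2_2_params.
#include <cstdint>
#include <cstdio>

int main() {
  const uint64_t base_log = 22, level_count = 1;
  const uint64_t rep_bit_count = base_log * level_count;
  const uint64_t non_rep_bit_count = 64 - rep_bit_count;

  uint64_t input = 0x89ABCDEF01234567ull; // arbitrary example value
  // Keep one extra bit so we can round to nearest instead of truncating.
  uint64_t res = input >> (non_rep_bit_count - 1);
  uint64_t rounding_bit = res & 1u;
  res = (res + 1) >> 1;                // round half up
  res &= (~0ull) >> non_rep_bit_count; // reduce mod 2^rep_bit_count
  // Balance the representation: map 2^(rep-1) down when needed.
  uint64_t shifted = rounding_bit << (rep_bit_count - 1);
  uint64_t need_balance = (((res - 1) | shifted) & res) >> (rep_bit_count - 1);
  uint64_t state = res - (need_balance << rep_bit_count);

  printf("decomposer state = 0x%016llx\n", (unsigned long long)state);
  return 0;
}
```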
@@ -10,6 +10,44 @@ uint32_t cuda_get_device() {

void cuda_set_device(uint32_t gpu_index) {
  check_cuda_error(cudaSetDevice(gpu_index));
  static bool SETUP_MEM_AND_WARMUP = true;
  if (SETUP_MEM_AND_WARMUP) {
    const size_t warmup_size = 10L * 1024 * 1024 * 1024; // 10 GB, just for testing

    // Get the default memory pool
    cudaMemPool_t default_pool;
    check_cuda_error(cudaDeviceGetDefaultMemPool(&default_pool, gpu_index));

    // Enable opportunistic reuse (may be on by default, but setting it
    // explicitly is good practice)
    int reuse = 1;
    check_cuda_error(cudaMemPoolSetAttribute(
        default_pool, cudaMemPoolReuseAllowOpportunistic, &reuse));

    // Prevent memory from being released back to the OS too soon
    size_t threshold = warmup_size;
    check_cuda_error(cudaMemPoolSetAttribute(
        default_pool, cudaMemPoolAttrReleaseThreshold, &threshold));

    // Warm up the pool by allocating and freeing a large block
    cudaStream_t stream;
    check_cuda_error(cudaStreamCreate(&stream));

    void *warmup_ptr = nullptr;
    check_cuda_error(cudaMallocAsync(&warmup_ptr, warmup_size, stream));
    check_cuda_error(cudaFreeAsync(warmup_ptr, stream));

    // Sync to ensure the pool has grown
    check_cuda_error(cudaStreamSynchronize(stream));

    printf("Default CUDA memory pool warmed up with 10 GB and opportunistic "
           "reuse enabled.\n");

    // Clean up
    check_cuda_error(cudaStreamDestroy(stream));
    SETUP_MEM_AND_WARMUP = false;
  }
}

cudaEvent_t cuda_create_event(uint32_t gpu_index) {
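For reference, the pool attributes set above can be read back from the host. A small sketch (assumes device 0 and only standard CUDA runtime calls; illustrative, not part of the change):

```cpp
// Read back the release threshold and reuse policy that the warm-up code
// above configures on the device's default memory pool.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int device = 0;
  cudaSetDevice(device);

  cudaMemPool_t pool;
  cudaDeviceGetDefaultMemPool(&pool, device);

  unsigned long long threshold = 0;
  cudaMemPoolGetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold);

  int reuse = 0;
  cudaMemPoolGetAttribute(pool, cudaMemPoolReuseAllowOpportunistic, &reuse);

  printf("release threshold: %llu bytes, opportunistic reuse: %d\n",
         threshold, reuse);
  return 0;
}
```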
@@ -63,7 +63,7 @@ template <class params> __device__ void NSMFFT_direct(double2 *A) {
  }

  Index twiddle_shift = 1;
  for (Index l = LOG2_DEGREE - 1; l >= 1; --l) {
  for (Index l = LOG2_DEGREE - 1; l >= 5; --l) {
    Index lane_mask = 1 << (l - 1);
    Index thread_mask = (1 << l) - 1;
    twiddle_shift <<= 1;
@@ -96,8 +96,43 @@ template <class params> __device__ void NSMFFT_direct(double2 *A) {
      tid = tid + STRIDE;
    }
  }
  __syncthreads();

  for (Index l = 4; l >= 1; --l) {
    Index lane_mask = 1 << (l - 1);
    Index thread_mask = (1 << l) - 1;
    twiddle_shift <<= 1;

    tid = threadIdx.x;
    __syncwarp();
    double2 reg_A[BUTTERFLY_DEPTH];
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
      Index rank = tid & thread_mask;
      bool u_stays_in_register = rank < lane_mask;
      reg_A[i] = (u_stays_in_register) ? v[i] : u[i];
      tid = tid + STRIDE;
    }
    __syncwarp();

    tid = threadIdx.x;
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
      Index rank = tid & thread_mask;
      bool u_stays_in_register = rank < lane_mask;
      w = shfl_xor_double2(reg_A[i], 1 << (l - 1), 0xFFFFFFFF);
      u[i] = (u_stays_in_register) ? u[i] : w;
      v[i] = (u_stays_in_register) ? w : v[i];
      w = negtwiddles[tid / lane_mask + twiddle_shift];

      w *= v[i];

      v[i] = u[i] - w;
      u[i] = u[i] + w;
      tid = tid + STRIDE;
    }
  }

  __syncthreads();
  // store registers in SM
  tid = threadIdx.x;
#pragma unroll
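In the new last four levels the butterfly partners are obtained with warp shuffles rather than through shared memory. A host-side sketch of the partner-index arithmetic (illustrative only):

```cpp
// In levels l = 4..1 the butterfly partner of a thread is
// threadIdx.x ^ lane_mask with lane_mask = 8, 4, 2 or 1. Those distances
// stay inside a 32-lane warp, which is why the exchange above can use
// shfl_xor_double2 instead of a round trip through shared memory.
#include <cstdio>

int main() {
  int tid = 13; // example thread index inside a warp
  for (int l = 4; l >= 1; --l) {
    int lane_mask = 1 << (l - 1);
    int partner = tid ^ lane_mask;
    printf("level %d: thread %d exchanges with lane %d (distance %d)\n", l,
           tid, partner, lane_mask);
  }
  return 0;
}
```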
@@ -109,6 +144,119 @@ template <class params> __device__ void NSMFFT_direct(double2 *A) {
  __syncthreads();
}

/*
 * negacyclic fft optimized for 2_2 params
 * it uses the twiddles from shared memory for extra performance
 * this is possible because we know that for 2_2 params enough shared memory
 * is available
 * the fft is returned in registers to avoid extra synchronizations
 */
template <class params>
__device__ void NSMFFT_direct_2_2_params(double2 *A, double2 *fft_out,
                                         double2 *shared_twiddles) {

  /* We don't make bit reverse here, since twiddles are already reversed
   * Each thread is always in charge of "opt/2" pairs of coefficients,
   * which is why we always loop through N/2 by N/opt strides
   * The pragma unroll instruction tells the compiler to unroll the
   * full loop, which should increase performance
   */

  constexpr Index BUTTERFLY_DEPTH = params::opt >> 1;
  constexpr Index LOG2_DEGREE = params::log2_degree;
  constexpr Index HALF_DEGREE = params::degree >> 1;
  constexpr Index STRIDE = params::degree / params::opt;

  Index tid = threadIdx.x;
  double2 u[BUTTERFLY_DEPTH], v[BUTTERFLY_DEPTH], w;

  // switch register order
  for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
    u[i] = fft_out[i];
    v[i] = fft_out[i + params::opt / 2];
  }

  // level 1
  // we don't do an actual complex multiplication on level 1 since we have
  // only one twiddle; its real and imaginary parts are equal, so we can
  // multiply with simpler operations
#pragma unroll
  for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
    w = v[i] * (double2){0.707106781186547461715008466854,
                         0.707106781186547461715008466854};
    v[i] = u[i] - w;
    u[i] = u[i] + w;
  }

  Index twiddle_shift = 1;
  for (Index l = LOG2_DEGREE - 1; l >= 5; --l) {
    Index lane_mask = 1 << (l - 1);
    Index thread_mask = (1 << l) - 1;
    twiddle_shift <<= 1;

    tid = threadIdx.x;
    __syncthreads();
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
      Index rank = tid & thread_mask;
      bool u_stays_in_register = rank < lane_mask;
      A[tid] = (u_stays_in_register) ? v[i] : u[i];
      tid = tid + STRIDE;
    }
    __syncthreads();

    tid = threadIdx.x;
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
      Index rank = tid & thread_mask;
      bool u_stays_in_register = rank < lane_mask;
      w = A[tid ^ lane_mask];
      u[i] = (u_stays_in_register) ? u[i] : w;
      v[i] = (u_stays_in_register) ? w : v[i];
      w = shared_twiddles[tid / lane_mask + twiddle_shift];

      w *= v[i];

      v[i] = u[i] - w;
      u[i] = u[i] + w;
      tid = tid + STRIDE;
    }
  }

  for (Index l = 4; l >= 1; --l) {
    Index lane_mask = 1 << (l - 1);
    Index thread_mask = (1 << l) - 1;
    twiddle_shift <<= 1;

    tid = threadIdx.x;
    double2 reg_A[BUTTERFLY_DEPTH];

    tid = threadIdx.x;
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
      Index rank = tid & thread_mask;
      bool u_stays_in_register = rank < lane_mask;
      reg_A[i] = (u_stays_in_register) ? v[i] : u[i];
      w = shfl_xor_double2(reg_A[i], 1 << (l - 1), 0xFFFFFFFF);
      u[i] = (u_stays_in_register) ? u[i] : w;
      v[i] = (u_stays_in_register) ? w : v[i];
      w = shared_twiddles[tid / lane_mask + twiddle_shift];

      w *= v[i];

      v[i] = u[i] - w;
      u[i] = u[i] + w;
      tid = tid + STRIDE;
    }
  }

  // Return the result in registers; no need to synchronize here, only if we
  // need to use the same shared memory afterwards
  for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
    fft_out[i] = u[i];
    fft_out[i + params::opt / 2] = v[i];
  }
}

/*
 * negacyclic inverse fft
 */
@@ -144,7 +292,46 @@ template <class params> __device__ void NSMFFT_inverse(double2 *A) {
  }

  Index twiddle_shift = DEGREE;
  for (Index l = 1; l <= LOG2_DEGREE - 1; ++l) {
  for (Index l = 1; l <= 4; ++l) {
    Index lane_mask = 1 << (l - 1);
    Index thread_mask = (1 << l) - 1;
    tid = threadIdx.x;
    twiddle_shift >>= 1;

    // at this point registers are ready for the butterfly
    tid = threadIdx.x;
    __syncwarp();
    double2 reg_A[BUTTERFLY_DEPTH];
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
      w = (u[i] - v[i]);
      u[i] += v[i];
      v[i] = w * conjugate(negtwiddles[tid / lane_mask + twiddle_shift]);

      // keep one of the registers for the next iteration and stage the other
      // one in reg_A
      Index rank = tid & thread_mask;
      bool u_stays_in_register = rank < lane_mask;
      reg_A[i] = (u_stays_in_register) ? v[i] : u[i];

      tid = tid + STRIDE;
    }
    __syncwarp();

    // prepare registers for the next butterfly iteration
    tid = threadIdx.x;
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
      Index rank = tid & thread_mask;
      bool u_stays_in_register = rank < lane_mask;
      w = shfl_xor_double2(reg_A[i], 1 << (l - 1), 0xFFFFFFFF);
      u[i] = (u_stays_in_register) ? u[i] : w;
      v[i] = (u_stays_in_register) ? w : v[i];

      tid = tid + STRIDE;
    }
  }

  for (Index l = 5; l <= LOG2_DEGREE - 1; ++l) {
    Index lane_mask = 1 << (l - 1);
    Index thread_mask = (1 << l) - 1;
    tid = threadIdx.x;
@@ -201,6 +388,126 @@ template <class params> __device__ void NSMFFT_inverse(double2 *A) {
  __syncthreads();
}

/*
 * negacyclic inverse fft optimized for 2_2 params
 * it uses the twiddles from shared memory for extra performance
 * this is possible because we know that for 2_2 params enough shared memory
 * is available
 * the input comes from registers to avoid some synchronizations and shared
 * memory usage
 */
template <class params>
__device__ void NSMFFT_inverse_2_2_params(double2 *A, double2 *buffer_regs,
                                          double2 *shared_twiddles) {

  /* We don't make bit reverse here, since twiddles are already reversed
   * Each thread is always in charge of "opt/2" pairs of coefficients,
   * which is why we always loop through N/2 by N/opt strides
   * The pragma unroll instruction tells the compiler to unroll the
   * full loop, which should increase performance
   */

  constexpr Index BUTTERFLY_DEPTH = params::opt >> 1;
  constexpr Index LOG2_DEGREE = params::log2_degree;
  constexpr Index DEGREE = params::degree;
  constexpr Index HALF_DEGREE = params::degree >> 1;
  constexpr Index STRIDE = params::degree / params::opt;

  size_t tid = threadIdx.x;
  double2 u[BUTTERFLY_DEPTH], v[BUTTERFLY_DEPTH], w;

  // load into registers and divide by the compressed polynomial size
  for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
    u[i] = buffer_regs[i];
    v[i] = buffer_regs[i + params::opt / 2];

    u[i] /= DEGREE;
    v[i] /= DEGREE;
  }

  Index twiddle_shift = DEGREE;
  for (Index l = 1; l <= 4; ++l) {
    Index lane_mask = 1 << (l - 1);
    Index thread_mask = (1 << l) - 1;
    tid = threadIdx.x;
    twiddle_shift >>= 1;

    // at this point registers are ready for the butterfly
    tid = threadIdx.x;
    double2 reg_A[BUTTERFLY_DEPTH];
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
      w = (u[i] - v[i]);
      u[i] += v[i];
      v[i] = w * conjugate(shared_twiddles[tid / lane_mask + twiddle_shift]);

      tid = tid + STRIDE;
    }
    __syncwarp();

    // prepare registers for the next butterfly iteration
    tid = threadIdx.x;
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
      Index rank = tid & thread_mask;
      bool u_stays_in_register = rank < lane_mask;
      reg_A[i] = (u_stays_in_register) ? v[i] : u[i];
      w = shfl_xor_double2(reg_A[i], 1 << (l - 1), 0xFFFFFFFF);
      u[i] = (u_stays_in_register) ? u[i] : w;
      v[i] = (u_stays_in_register) ? w : v[i];

      tid = tid + STRIDE;
    }
  }

  for (Index l = 5; l <= LOG2_DEGREE - 1; ++l) {
    Index lane_mask = 1 << (l - 1);
    Index thread_mask = (1 << l) - 1;
    tid = threadIdx.x;
    twiddle_shift >>= 1;

    // at this point registers are ready for the butterfly
    tid = threadIdx.x;
    __syncthreads();
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
      w = (u[i] - v[i]);
      u[i] += v[i];
      v[i] = w * conjugate(shared_twiddles[tid / lane_mask + twiddle_shift]);

      // keep one of the registers for the next iteration and store the other
      // one in shared memory
      Index rank = tid & thread_mask;
      bool u_stays_in_register = rank < lane_mask;
      A[tid] = (u_stays_in_register) ? v[i] : u[i];

      tid = tid + STRIDE;
    }
    __syncthreads();

    // prepare registers for the next butterfly iteration
    tid = threadIdx.x;
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
      Index rank = tid & thread_mask;
      bool u_stays_in_register = rank < lane_mask;
      w = A[tid ^ lane_mask];
      u[i] = (u_stays_in_register) ? u[i] : w;
      v[i] = (u_stays_in_register) ? w : v[i];

      tid = tid + STRIDE;
    }
  }

  // last iteration
#pragma unroll
  for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
    w = (u[i] - v[i]);
    buffer_regs[i] = u[i] + v[i];
    buffer_regs[i + params::opt / 2] =
        w * (double2){0.707106781186547461715008466854,
                      -0.707106781186547461715008466854};
  }
}

/*
 * global batch fft
 * does fft in half size
@@ -17,6 +17,12 @@ __device__ inline int get_start_ith_ggsw(int i, uint32_t polynomial_size,
  return i * polynomial_size / 2 * (glwe_dimension + 1) * (glwe_dimension + 1) *
         level_count;
}
template <uint32_t polynomial_size, uint32_t glwe_dimension,
          uint32_t level_count>
__device__ inline int get_start_ith_ggsw_2_2_params(int i) {
  return i * polynomial_size / 2 * (glwe_dimension + 1) * (glwe_dimension + 1) *
         level_count;
}

__device__ inline int get_start_ith_ggsw_128(int i, uint32_t polynomial_size,
                                             int glwe_dimension,
@@ -49,6 +55,17 @@ __device__ T *get_ith_mask_kth_block(T *ptr, int i, int k, int level,
              k * polynomial_size / 2 * (glwe_dimension + 1)];
}

template <typename T, uint32_t polynomial_size, uint32_t glwe_dimension,
          uint32_t level_count, uint32_t level_id>
__device__ const T *get_ith_mask_kth_block_2_2_params(const T *ptr,
                                                      int iteration, int k) {
  return &ptr[get_start_ith_ggsw_2_2_params<polynomial_size, glwe_dimension,
                                            level_count>(iteration) +
              (level_count - level_id - 1) * polynomial_size / 2 *
                  (glwe_dimension + 1) * (glwe_dimension + 1) +
              k * polynomial_size / 2 * (glwe_dimension + 1)];
}

template <typename T>
__device__ const T *
get_ith_mask_kth_block_128(const T *ptr, int i, int k, int level,
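Since all parameters are template arguments, the offsets resolve at compile time. A host-side sketch of the same offset arithmetic, evaluated for the 2_2 parameter set (the helper names `ggsw_start` and `mask_block` are hypothetical, for illustration only):

```cpp
// Offset (in double2 elements) into the Fourier-domain bootstrapping key for
// polynomial_size = 2048, glwe_dimension = 1, level_count = 1.
#include <cstdint>
#include <cstdio>

constexpr uint64_t poly_size = 2048, glwe_dim = 1, levels = 1;

constexpr uint64_t ggsw_start(uint64_t i) {
  return i * (poly_size / 2) * (glwe_dim + 1) * (glwe_dim + 1) * levels;
}

constexpr uint64_t mask_block(uint64_t iteration, uint64_t k,
                              uint64_t level_id = 0) {
  return ggsw_start(iteration) +
         (levels - level_id - 1) * (poly_size / 2) * (glwe_dim + 1) *
             (glwe_dim + 1) +
         k * (poly_size / 2) * (glwe_dim + 1);
}

int main() {
  // e.g. iteration 3, k = 1
  printf("offset = %llu double2 elements\n",
         (unsigned long long)mask_block(3, 1));
  return 0;
}
```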
@@ -22,6 +22,16 @@ get_join_buffer_element(int level_id, int glwe_id, G &group,
                        double2 *global_memory_buffer, uint32_t polynomial_size,
                        uint32_t glwe_dimension, bool support_dsm);

template <typename G, uint32_t level_id, uint32_t glwe_dimension>
__device__ __forceinline__ double2 *
get_join_buffer_element_tbc(int glwe_id, G &cluster,
                            double2 *shared_memory_buffer) {
  double2 *buffer_slice;
  buffer_slice = cluster.map_shared_rank(
      shared_memory_buffer, glwe_id + level_id * (glwe_dimension + 1));
  return buffer_slice;
}

template <typename G>
__device__ double *get_join_buffer_element_128(
    int level_id, int glwe_id, G &group, double *global_memory_buffer,
@@ -139,6 +149,59 @@ __device__ void mul_ggsw_glwe_in_fourier_domain_128(
  __syncthreads();
}

/** Perform the matrix multiplication between the GGSW and the GLWE,
 * each block operating on a single level for mask and body.
 * Both operands should be in the Fourier domain.
 *
 * This function assumes that 2_2 params are used:
 * - Thread blocks at dimension z relate to the decomposition level.
 * - Thread blocks at dimension y relate to the glwe dimension.
 * - polynomial_size / params::opt threads are available per block
 * - the local fft is read from registers
 * To avoid a cluster synchronization the accumulator output is different from
 * the input, and on the next iteration they are switched to act as a
 * ping-pong buffer.
 */
template <typename G, class params, uint32_t polynomial_size,
          uint32_t glwe_dimension, uint32_t level_count>
__device__ void mul_ggsw_glwe_in_fourier_domain_2_2_params(
    double2 *fft, double2 *fft_regs, double2 *buffer_regs,
    const double2 *__restrict__ bootstrapping_key, int iteration, G &group,
    int this_block_rank) {
  // Continues multiplying fft by every polynomial in that particular bsk level
  // Each y-block accumulates in a different polynomial at each iteration
  // We accumulate in registers to free shared memory
  // In 2_2 params we only have one level
  constexpr uint32_t level_id = 0;
  // The first product doesn't need to use dsm
  auto bsk_slice =
      get_ith_mask_kth_block_2_2_params<double2, polynomial_size,
                                        glwe_dimension, level_count, level_id>(
          bootstrapping_key, iteration, this_block_rank);
  auto bsk_poly = bsk_slice + blockIdx.y * polynomial_size / 2;
  polynomial_product_accumulate_in_fourier_domain_2_2_params<params, double2,
                                                             true>(
      buffer_regs, fft_regs, bsk_poly);

  // Synchronize to ensure all blocks have written their fft result
  group.sync();
  constexpr uint32_t glwe_id = 1;
  int idx = (glwe_id + this_block_rank) % (glwe_dimension + 1);
  bsk_slice =
      get_ith_mask_kth_block_2_2_params<double2, polynomial_size,
                                        glwe_dimension, level_count, level_id>(
          bootstrapping_key, iteration, idx);
  bsk_poly = bsk_slice + blockIdx.y * polynomial_size / 2;
  auto fft_slice =
      get_join_buffer_element_tbc<G, level_id, glwe_dimension>(idx, group, fft);
  polynomial_product_accumulate_in_fourier_domain_2_2_params<params, double2,
                                                             false>(
      buffer_regs, fft_slice, bsk_poly);

  // We don't need to synchronize here because we are going to use a buffer
  // different from the input. In 2_2 params, level_count = 1, so we can just
  // return the buffer in registers to avoid synchronizations and shared
  // memory usage
}

template <typename Torus>
void execute_pbs_async(
    cudaStream_t const *streams, uint32_t const *gpu_indexes,
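The rotation of the column index is what lets the blocks of the cluster work on disjoint bsk columns in each step. A host-side sketch of that index arithmetic for glwe_dimension = 1 (illustrative):

```cpp
// With glwe_dimension = 1 there are two polynomial columns per GLWE. Each
// block first multiplies its own column (this_block_rank) and then the
// rotated one, so the two blocks never read the same bsk column at once.
#include <cstdio>

int main() {
  const int glwe_dimension = 1;
  for (int this_block_rank = 0; this_block_rank <= glwe_dimension;
       ++this_block_rank) {
    int first = this_block_rank;
    int second = (1 + this_block_rank) % (glwe_dimension + 1);
    printf("block %d multiplies columns %d then %d\n", this_block_rank, first,
           second);
  }
  return 0;
}
```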
@@ -415,8 +415,7 @@ uint64_t scratch_cuda_multi_bit_programmable_bootstrap_64(
          input_lwe_ciphertext_count, glwe_dimension, polynomial_size,
          level_count, cuda_get_max_shared_memory(gpu_index));

  if (supports_tbc &&
      !(input_lwe_ciphertext_count > num_sms / 2 && supports_cg))
  if (supports_tbc)
    return scratch_cuda_tbc_multi_bit_programmable_bootstrap<uint64_t>(
        stream, gpu_index, (pbs_buffer<uint64_t, MULTI_BIT> **)buffer,
        glwe_dimension, polynomial_size, level_count,
@@ -489,6 +488,17 @@ uint32_t get_lwe_chunk_size(uint32_t gpu_index, uint32_t max_num_pbs,
  int log2_max_num_pbs = log2_int(max_num_pbs);
  if (log2_max_num_pbs > 13)
    ith_divisor = log2_max_num_pbs - 11;
#else
  // When there are few samples we are interested in using a larger chunk size
  // so the keybundle can saturate the GPU. To obtain homogeneous waves we use
  // half of the SMs as the chunk size; by doing so we always get a multiple of
  // the number of SMs, removing the tailing effect. We don't divide by 4
  // because some flavors of H100 might not have a number of SMs divisible by
  // 4. This is applied only to small sample counts (<= 8) because otherwise it
  // can have a negative over-saturation effect.
  if (max_num_pbs <= 8) {
    return num_sms / 2;
  }
#endif

  for (int i = sqrt(x); i >= 1; i--) {
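A host-side sketch of this small-batch heuristic (the helper name is hypothetical and the SM count is an assumed example; 132 corresponds to an H100 SXM):

```cpp
// For very small batches (<= 8 PBS) return half the SM count so keybundle
// waves are a multiple of the number of SMs; larger batches fall through to
// the divisor search used by the real function.
#include <cstdint>
#include <cstdio>

uint32_t chunk_size_for_small_batches(uint32_t max_num_pbs, uint32_t num_sms) {
  if (max_num_pbs <= 8)
    return num_sms / 2;
  return 0; // placeholder: real code searches for a divisor instead
}

int main() {
  const uint32_t num_sms = 132; // e.g. an H100 SXM
  for (uint32_t n : {1u, 4u, 8u, 16u})
    printf("max_num_pbs=%u -> chunk=%u\n", n,
           chunk_size_for_small_batches(n, num_sms));
  return 0;
}
```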
@@ -146,6 +146,144 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
  }
}

// Calculates the keybundles for 2_2 params:
// Lwe dimension = 920
// Polynomial size = 2048
// Grouping factor = 4
// Glwe dimension = 1
// PBS level = 1
// In this initial version everything is hardcoded as constexpr; we
// will wrap it up in a nicer/cleaner version in the future.
// Additionally, we initialize an int8_t vector with the coefficients used in
// the monomial multiplication. The size of this vector is 3x2048 and the
// coefficients are: [0 .. 2047] = -1, [2048 .. 4095] = 1, [4096 .. 6143] = -1.
// Then we can just calculate the offset needed to apply these coefficients,
// and the operation turns into a pointwise vector multiplication, avoiding
// extra instructions other than MADD.
template <typename Torus, class params, sharedMemDegree SMD>
__global__ void device_multi_bit_programmable_bootstrap_keybundle_2_2_params(
    const Torus *__restrict__ lwe_array_in,
    const Torus *__restrict__ lwe_input_indexes, double2 *keybundle_array,
    const Torus *__restrict__ bootstrapping_key, uint32_t lwe_offset,
    uint32_t lwe_chunk_size, uint32_t keybundle_size_per_input) {

  constexpr uint32_t lwe_dimension = 920;
  constexpr uint32_t polynomial_size = 2048;
  constexpr uint32_t grouping_factor = 4;
  constexpr uint32_t glwe_dimension = 1;
  constexpr uint32_t level_count = 1;

  extern __shared__ int8_t sharedmem[];
  int8_t *selected_memory;
  selected_memory = sharedmem;

  int8_t *precalc_coefs =
      selected_memory + (sizeof(uint32_t) * (1 << grouping_factor));
  for (int i = 0; i < params::opt; i++) {
    precalc_coefs[threadIdx.x + i * (params::degree / params::opt)] = -1;
    precalc_coefs[threadIdx.x + i * (params::degree / params::opt) +
                  params::degree] = 1;
    precalc_coefs[threadIdx.x + i * (params::degree / params::opt) +
                  2 * params::degree] = -1;
  }

  double2 *shared_fft = (double2 *)(precalc_coefs + polynomial_size * 3);
  double2 *shared_twiddles = shared_fft + (polynomial_size / 2);
  for (int k = 0; k < params::opt / 2; k++) {
    shared_twiddles[threadIdx.x + k * (params::degree / params::opt)] =
        negtwiddles[threadIdx.x + k * (params::degree / params::opt)];
  }

  // Ids
  constexpr uint32_t level_id = 0;
  uint32_t glwe_id = blockIdx.y / (glwe_dimension + 1);
  uint32_t poly_id = blockIdx.y % (glwe_dimension + 1);
  uint32_t lwe_iteration = (blockIdx.x % lwe_chunk_size + lwe_offset);
  uint32_t input_idx = blockIdx.x / lwe_chunk_size;

  if (lwe_iteration < (lwe_dimension / grouping_factor)) {

    const Torus *block_lwe_array_in =
        &lwe_array_in[lwe_input_indexes[input_idx] * (lwe_dimension + 1)];

    double2 *keybundle = keybundle_array +
                         // select the input
                         input_idx * keybundle_size_per_input;

    ////////////////////////////////////////////////////////////
    // Computes all keybundles
    uint32_t rev_lwe_iteration =
        ((lwe_dimension / grouping_factor) - lwe_iteration - 1);

    // ////////////////////////////////
    // Keygen guarantees the first term is a constant term of the polynomial,
    // no polynomial multiplication required
    const Torus *bsk_slice = get_multi_bit_ith_lwe_gth_group_kth_block(
        bootstrapping_key, 0, rev_lwe_iteration, glwe_id, level_id,
        grouping_factor, 2 * polynomial_size, glwe_dimension, level_count);
    const Torus *bsk_poly_ini = bsk_slice + poly_id * params::degree;

    Torus reg_acc[params::opt];

    copy_polynomial_in_regs<Torus, params::opt, params::degree / params::opt>(
        bsk_poly_ini, reg_acc);

    constexpr int offset = polynomial_size * (glwe_dimension + 1) *
                           (glwe_dimension + 1) * level_count;
    // Precalculate the monomial degrees and store them in shared memory
    uint32_t *monomial_degrees = (uint32_t *)selected_memory;

    if (threadIdx.x < (1 << grouping_factor)) {
      const Torus *lwe_array_group =
          block_lwe_array_in + rev_lwe_iteration * grouping_factor;
      monomial_degrees[threadIdx.x] = calculates_monomial_degree<Torus, params>(
          lwe_array_group, threadIdx.x, grouping_factor);
    }
    __syncthreads();

    // Accumulate the other terms
    for (int g = 1; g < (1 << grouping_factor); g++) {

      uint32_t monomial_degree = monomial_degrees[g];

      int full_cycles_count = monomial_degree / params::degree;
      int remainder_degrees = monomial_degree % params::degree;
      int jump = full_cycles_count * params::degree + params::degree -
                 remainder_degrees;

      const Torus *bsk_poly = bsk_poly_ini + g * offset;
      // Multiply by the bsk element
      polynomial_accumulate_monic_monomial_mul_on_regs_precalc<Torus, params>(
          reg_acc, bsk_poly, precalc_coefs + jump, monomial_degree);
    }

    // Move from local memory back to shared memory, but as complex values
    double2 fft_regs[params::opt / 2];
    double2 *fft = shared_fft;
#pragma unroll
    for (int i = 0; i < params::opt / 2; i++) {
      fft_regs[i] =
          make_double2(__ll2double_rn((int64_t)reg_acc[i]) /
                           (double)std::numeric_limits<Torus>::max(),
                       __ll2double_rn((int64_t)reg_acc[i + params::opt / 2]) /
                           (double)std::numeric_limits<Torus>::max());
    }

    NSMFFT_direct_2_2_params<HalfDegree<params>>(fft, fft_regs,
                                                 shared_twiddles);

    // lwe iteration
    auto keybundle_out = get_ith_mask_kth_block(
        keybundle, blockIdx.x % lwe_chunk_size, glwe_id, level_id,
        polynomial_size, glwe_dimension, level_count);

    auto keybundle_poly = keybundle_out + poly_id * params::degree / 2;
    copy_polynomial_from_regs<double2, params::opt / 2,
                              params::degree / params::opt>(fft_regs,
                                                            keybundle_poly);
  }
}

template <typename Torus, class params, sharedMemDegree SMD, bool is_first_iter>
__global__ void __launch_bounds__(params::degree / params::opt)
    device_multi_bit_programmable_bootstrap_accumulate_step_one(
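The 3N sign table plus the `jump` offset reproduces the wrap-around signs of a negacyclic multiplication by X^monomial_degree, which is why the kernel only needs a pointwise multiply-add. A host-side sketch that checks this against the direct sign rule (small N, illustrative):

```cpp
// Verify that the [-1 | +1 | -1] table addressed at `jump` matches the sign
// pattern of a negacyclic monomial multiplication.
#include <cstdio>
#include <vector>

int main() {
  const int N = 16; // small stand-in for params::degree = 2048
  std::vector<int> coefs(3 * N);
  for (int i = 0; i < N; i++) {
    coefs[i] = -1;
    coefs[i + N] = 1;
    coefs[i + 2 * N] = -1;
  }

  for (int d : {3, N + 5}) { // monomial degrees in [0, 2N)
    int full_cycles = d / N, r = d % N;
    int jump = full_cycles * N + N - r;
    bool ok = true;
    for (int idx = 0; idx < N; idx++) {
      int expected = (idx >= r ? 1 : -1) * (full_cycles % 2 ? -1 : 1);
      ok &= (coefs[jump + idx] == expected);
    }
    printf("degree %d: table matches direct sign rule: %s\n", d,
           ok ? "yes" : "no");
  }
  return 0;
}
```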
@@ -562,20 +700,45 @@ __host__ void execute_compute_keybundle(
                       (glwe_dimension + 1) * (glwe_dimension + 1), level_count);
  dim3 thds(polynomial_size / params::opt, 1, 1);

  if (max_shared_memory < full_sm_keybundle)
  if (max_shared_memory < full_sm_keybundle) {
    device_multi_bit_programmable_bootstrap_keybundle<Torus, params, NOSM>
        <<<grid_keybundle, thds, 0, stream>>>(
            lwe_array_in, lwe_input_indexes, keybundle_fft, bootstrapping_key,
            lwe_dimension, glwe_dimension, polynomial_size, grouping_factor,
            level_count, lwe_offset, chunk_size, keybundle_size_per_input,
            d_mem, full_sm_keybundle);
  else
    device_multi_bit_programmable_bootstrap_keybundle<Torus, params, FULLSM>
        <<<grid_keybundle, thds, full_sm_keybundle, stream>>>(
            lwe_array_in, lwe_input_indexes, keybundle_fft, bootstrapping_key,
            lwe_dimension, glwe_dimension, polynomial_size, grouping_factor,
            level_count, lwe_offset, chunk_size, keybundle_size_per_input,
            d_mem, 0);
  } else {
    bool supports_tbc =
        has_support_to_cuda_programmable_bootstrap_tbc_multi_bit<uint64_t>(
            num_samples, glwe_dimension, polynomial_size, level_count,
            cuda_get_max_shared_memory(gpu_index));

    if (supports_tbc && polynomial_size == 2048 && grouping_factor == 4 &&
        level_count == 1 && glwe_dimension == 1 && lwe_dimension == 920) {
      dim3 thds_new_keybundle(512, 1, 1);
      check_cuda_error(cudaFuncSetAttribute(
          device_multi_bit_programmable_bootstrap_keybundle_2_2_params<
              Torus, Degree<2048>, FULLSM>,
          cudaFuncAttributeMaxDynamicSharedMemorySize, 3 * full_sm_keybundle));
      cudaFuncSetCacheConfig(
          device_multi_bit_programmable_bootstrap_keybundle_2_2_params<
              Torus, Degree<2048>, FULLSM>,
          cudaFuncCachePreferShared);
      check_cuda_error(cudaGetLastError());
      device_multi_bit_programmable_bootstrap_keybundle_2_2_params<
          Torus, Degree<2048>, FULLSM><<<grid_keybundle, thds_new_keybundle,
                                         3 * full_sm_keybundle, stream>>>(
          lwe_array_in, lwe_input_indexes, keybundle_fft, bootstrapping_key,
          lwe_offset, chunk_size, keybundle_size_per_input);
    } else {
      device_multi_bit_programmable_bootstrap_keybundle<Torus, params, FULLSM>
          <<<grid_keybundle, thds, full_sm_keybundle, stream>>>(
              lwe_array_in, lwe_input_indexes, keybundle_fft, bootstrapping_key,
              lwe_dimension, glwe_dimension, polynomial_size, grouping_factor,
              level_count, lwe_offset, chunk_size, keybundle_size_per_input,
              d_mem, 0);
    }
  }
  check_cuda_error(cudaGetLastError());
}
@@ -181,6 +181,208 @@ __global__ void __launch_bounds__(params::degree / params::opt)
  }
}

// Specialized version for the multi-bit bootstrap using 2_2 params:
// Polynomial size = 2048
// PBS level = 1
// Grouping factor = 4
// PBS base = 22
// Glwe dimension = 1
// At the moment everything is hardcoded as constexpr, but later
// we will provide a cleaner/nicer way to handle it.
// Main optimizations:
// - Leverage shared memory to reduce one cluster synchronization. A
//   ping-pong buffer is used for that, so everything is synchronized
//   automatically after 2 iterations
// - Move everything to registers to avoid shared memory synchronizations
// - Use a register-based fft that needs minimal synchronizations
// - Register-based Fourier-domain multiplication. Transfer ffts between
//   blocks instead of the accumulator.
template <typename Torus, class params, sharedMemDegree SMD>
__global__ void __launch_bounds__(params::degree / params::opt)
    device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params(
        Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
        const Torus *__restrict__ lut_vector,
        const Torus *__restrict__ lut_vector_indexes,
        const Torus *__restrict__ lwe_array_in,
        const Torus *__restrict__ lwe_input_indexes,
        const double2 *__restrict__ keybundle_array, double2 *join_buffer,
        Torus *global_accumulator, uint32_t lwe_dimension, uint32_t lwe_offset,
        uint32_t lwe_chunk_size, uint32_t keybundle_size_per_input,
        uint32_t num_many_lut, uint32_t lut_stride) {

  constexpr uint32_t level_count = 1;
  constexpr uint32_t grouping_factor = 4;
  constexpr uint32_t polynomial_size = 2048;
  constexpr uint32_t glwe_dimension = 1;
  constexpr uint32_t base_log = 22;
  cluster_group cluster = this_cluster();
  auto this_block_rank = cluster.block_index().y;
  // We use shared memory for the polynomials that are used often during the
  // bootstrap, since shared memory is kept in L1 cache and accessing it is
  // much faster than global memory
  extern __shared__ int8_t sharedmem[];
  int8_t *selected_memory;

  // When using 2_2 params and tbc we know everything fits in shared memory
  // The first (polynomial_size/2) * sizeof(double2) bytes are reserved for
  // the external product using distributed shared memory
  selected_memory = sharedmem;
  // We know that dsm is supported and we have enough memory
  constexpr uint32_t num_buffers_ping_pong = 2;
  selected_memory += sizeof(Torus) * polynomial_size * num_buffers_ping_pong;

  double2 *accumulator_ping = (double2 *)sharedmem;
  double2 *accumulator_pong = accumulator_ping + (polynomial_size / 2);
  double2 *shared_twiddles = accumulator_pong + (polynomial_size / 2);
  double2 *shared_fft = shared_twiddles + (polynomial_size / 2);
  // The rotated accumulator shares the same memory space as the twiddles.
  // It is only used during the sample extract, so it is safe to use it.
  Torus *accumulator_rotated = (Torus *)selected_memory;

  // Copy the twiddles from global to shared memory for extra performance
  for (int k = 0; k < params::opt / 2; k++) {
    shared_twiddles[threadIdx.x + k * (params::degree / params::opt)] =
        negtwiddles[threadIdx.x + k * (params::degree / params::opt)];
  }

  // The first dimension of the block is used to determine on which ciphertext
  // this block is operating, in the case of batch bootstraps
  const Torus *block_lwe_array_in =
      &lwe_array_in[lwe_input_indexes[blockIdx.x] * (lwe_dimension + 1)];

  const Torus *block_lut_vector =
      &lut_vector[lut_vector_indexes[blockIdx.x] * params::degree *
                  (glwe_dimension + 1)];

  Torus *global_accumulator_slice =
      &global_accumulator[(blockIdx.y + blockIdx.x * (glwe_dimension + 1)) *
                          params::degree];

  const double2 *keybundle =
      &keybundle_array[blockIdx.x * keybundle_size_per_input];

  // The rotated accumulator is moved to registers to free shared memory for
  // other potential improvements. By itself this change doesn't bring much
  // benefit.
  Torus reg_acc_rotated[params::opt];
  if (lwe_offset == 0) {
    // Put "b" in [0, 2N[
    Torus b_hat = 0;
    modulus_switch(block_lwe_array_in[lwe_dimension], b_hat,
                   params::log2_degree + 1);

    divide_by_monomial_negacyclic_2_2_params_inplace<
        Torus, params::opt, params::degree / params::opt>(
        reg_acc_rotated, &block_lut_vector[blockIdx.y * params::degree], b_hat);
  } else {
    // Load the accumulator calculated in previous iterations
    copy_polynomial_in_regs<Torus, params::opt, params::degree / params::opt>(
        global_accumulator_slice, reg_acc_rotated);
  }

  for (int i = 0; (i + lwe_offset) < lwe_dimension && i < lwe_chunk_size; i++) {
    // Perform a rounding to increase the accuracy of the
    // bootstrapped ciphertext
    init_decomposer_state_inplace_2_2_params<Torus, params::opt,
                                             params::degree / params::opt,
                                             base_log, level_count>(
        reg_acc_rotated);

    // This is the ping-pong buffer logic to avoid a cluster synchronization
    auto accumulator_fft = i % 2 ? accumulator_ping : accumulator_pong;

    double2 fft_out_regs[params::opt / 2];
    // Decompose the accumulator. Each block gets one level of the
    // decomposition, for the mask and the body (so block 0 will have the
    // accumulator decomposed at level 0, 1 at 1, etc.)
    decompose_and_compress_level_2_2_params<Torus, params, base_log>(
        fft_out_regs, reg_acc_rotated);

    NSMFFT_direct_2_2_params<HalfDegree<params>>(shared_fft, fft_out_regs,
                                                 shared_twiddles);
    // we move registers into shared memory to use dsm
    int tid = threadIdx.x;
    for (Index k = 0; k < params::opt / 4; k++) {
      accumulator_fft[tid] = fft_out_regs[k];
      accumulator_fft[tid + params::degree / 4] =
          fft_out_regs[k + params::opt / 4];
      tid = tid + params::degree / params::opt;
    }

    double2 buffer_regs[params::opt / 2];
    // Perform G^-1(ACC) * GGSW -> GLWE
    mul_ggsw_glwe_in_fourier_domain_2_2_params<
        cluster_group, params, polynomial_size, glwe_dimension, level_count>(
        accumulator_fft, fft_out_regs, buffer_regs, keybundle, i, cluster,
        this_block_rank);

    NSMFFT_inverse_2_2_params<HalfDegree<params>>(shared_fft, buffer_regs,
                                                  shared_twiddles);

    add_to_torus_2_2_params<Torus, params>(buffer_regs, reg_acc_rotated);
  }

  if (lwe_offset + lwe_chunk_size >= (lwe_dimension / grouping_factor)) {

    // Temporary copy to keep the other logic as it is
    for (int i = 0; i < params::opt; i++) {
      accumulator_rotated[threadIdx.x + i * (params::degree / params::opt)] =
          reg_acc_rotated[i];
    }
    __syncthreads();
    auto accumulator = accumulator_rotated;
    auto block_lwe_array_out =
        &lwe_array_out[lwe_output_indexes[blockIdx.x] *
                           (glwe_dimension * polynomial_size + 1) +
                       blockIdx.y * polynomial_size];

    if (blockIdx.y < glwe_dimension) {
      // Perform a sample extract. At this point, all blocks have the result,
      // but we do the computation at block 0 to avoid waiting for extra
      // blocks, in case they're not synchronized
      sample_extract_mask<Torus, params>(block_lwe_array_out, accumulator);

      if (num_many_lut > 1) {
        for (int i = 1; i < num_many_lut; i++) {
          auto next_lwe_array_out =
              lwe_array_out +
              (i * gridDim.x * (glwe_dimension * polynomial_size + 1));
          auto next_block_lwe_array_out =
              &next_lwe_array_out[lwe_output_indexes[blockIdx.x] *
                                      (glwe_dimension * polynomial_size + 1) +
                                  blockIdx.y * polynomial_size];

          sample_extract_mask<Torus, params>(next_block_lwe_array_out,
                                             accumulator, 1, i * lut_stride);
        }
      }
    } else if (blockIdx.y == glwe_dimension) {
      sample_extract_body<Torus, params>(block_lwe_array_out, accumulator, 0);
      if (num_many_lut > 1) {
        for (int i = 1; i < num_many_lut; i++) {

          auto next_lwe_array_out =
              lwe_array_out +
              (i * gridDim.x * (glwe_dimension * polynomial_size + 1));
          auto next_block_lwe_array_out =
              &next_lwe_array_out[lwe_output_indexes[blockIdx.x] *
                                      (glwe_dimension * polynomial_size + 1) +
                                  blockIdx.y * polynomial_size];

          sample_extract_body<Torus, params>(next_block_lwe_array_out,
                                             accumulator, 0, i * lut_stride);
        }
      }
    }
  } else {
    // Store the accumulator so it can be loaded in the next iterations
    copy_polynomial_from_regs<Torus, params::opt, params::degree / params::opt>(
        reg_acc_rotated, global_accumulator_slice);
  }
  // Before exiting the kernel we need to sync the cluster to ensure that
  // other blocks can still access the dsm in the ping-pong buffer
  cluster.sync();
}

template <typename Torus>
uint64_t get_buffer_size_sm_dsm_plus_tbc_multibit_programmable_bootstrap(
    uint32_t polynomial_size) {
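A rough host-side sketch of the dynamic shared memory budget implied by the pointer layout above, for polynomial_size = 2048 (sizes only, as an illustration; this is not the scratch-size computation used by the host code):

```cpp
// Shared memory regions used by the 2_2 accumulate kernel: two double2
// ping/pong FFT buffers, the twiddle table, and one FFT scratch buffer,
// each holding polynomial_size / 2 complex coefficients.
#include <cstdio>

int main() {
  const size_t poly = 2048;
  const size_t half = poly / 2;         // coefficients per buffer
  const size_t d2 = sizeof(double) * 2; // sizeof(double2)

  size_t ping = half * d2;
  size_t pong = half * d2;
  size_t twiddles = half * d2;
  size_t fft_scratch = half * d2;

  printf("ping+pong: %zu B, twiddles: %zu B, fft scratch: %zu B, total: %zu B\n",
         ping + pong, twiddles, fft_scratch,
         ping + pong + twiddles + fft_scratch);
  return 0;
}
```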
@@ -271,15 +473,32 @@ __host__ uint64_t scratch_tbc_multi_bit_programmable_bootstrap(
        cudaFuncCachePreferShared);
    check_cuda_error(cudaGetLastError());
  } else {
    check_cuda_error(cudaFuncSetAttribute(
        device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
                                                               FULLSM>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        full_sm_tbc_accumulate + minimum_sm_tbc_accumulate));
    cudaFuncSetCacheConfig(
        device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
                                                               FULLSM>,
        cudaFuncCachePreferShared);
    if (polynomial_size == 2048 && level_count == 1 && glwe_dimension == 1) {
      check_cuda_error(cudaFuncSetAttribute(
          device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params<
              Torus, params, FULLSM>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          full_sm_tbc_accumulate + 2 * minimum_sm_tbc_accumulate));
      check_cuda_error(cudaFuncSetAttribute(
          device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params<
              Torus, params, FULLSM>,
          cudaFuncAttributePreferredSharedMemoryCarveout,
          cudaSharedmemCarveoutMaxShared));
      check_cuda_error(cudaFuncSetCacheConfig(
          device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params<
              Torus, params, FULLSM>,
          cudaFuncCachePreferShared));
    } else {
      check_cuda_error(cudaFuncSetAttribute(
          device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
                                                                 FULLSM>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          full_sm_tbc_accumulate + minimum_sm_tbc_accumulate));
      cudaFuncSetCacheConfig(
          device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
                                                                 FULLSM>,
          cudaFuncCachePreferShared);
    }
    check_cuda_error(cudaGetLastError());
  }
@@ -382,16 +601,44 @@ __host__ void execute_tbc_external_product_loop(
          lut_stride));
    } else {
      config.dynamicSmemBytes = full_dm + minimum_dm;
      check_cuda_error(cudaLaunchKernelEx(
          &config,
          device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
                                                                 FULLSM>,
          lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
          lwe_array_in, lwe_input_indexes, keybundle_fft, buffer_fft,
          global_accumulator, lwe_dimension, glwe_dimension, polynomial_size,
          base_log, level_count, grouping_factor, lwe_offset, chunk_size,
          keybundle_size_per_input, d_mem, 0, supports_dsm, num_many_lut,
          lut_stride));
      if (polynomial_size == 2048 && grouping_factor == 4 && level_count == 1 &&
          glwe_dimension == 1 && base_log == 22) {

        config.dynamicSmemBytes = full_dm + 2 * minimum_dm;
        check_cuda_error(cudaFuncSetAttribute(
            device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params<
                Torus, params, FULLSM>,
            cudaFuncAttributeMaxDynamicSharedMemorySize,
            full_dm + 2 * minimum_dm));
        check_cuda_error(cudaFuncSetAttribute(
            device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params<
                Torus, params, FULLSM>,
            cudaFuncAttributePreferredSharedMemoryCarveout,
            cudaSharedmemCarveoutMaxShared));
        check_cuda_error(cudaFuncSetCacheConfig(
            device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params<
                Torus, params, FULLSM>,
            cudaFuncCachePreferShared));
        check_cuda_error(cudaLaunchKernelEx(
            &config,
            device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params<
                Torus, params, FULLSM>,
            lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
            lwe_array_in, lwe_input_indexes, keybundle_fft, buffer_fft,
            global_accumulator, lwe_dimension, lwe_offset, chunk_size,
            keybundle_size_per_input, num_many_lut, lut_stride));
      } else {
        check_cuda_error(cudaLaunchKernelEx(
            &config,
            device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
                                                                   FULLSM>,
            lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
            lwe_array_in, lwe_input_indexes, keybundle_fft, buffer_fft,
            global_accumulator, lwe_dimension, glwe_dimension, polynomial_size,
            base_log, level_count, grouping_factor, lwe_offset, chunk_size,
            keybundle_size_per_input, d_mem, 0, supports_dsm, num_many_lut,
            lut_stride));
      }
    }
  }
@@ -505,15 +752,27 @@ __host__ bool supports_thread_block_clusters_on_multibit_programmable_bootstrap(
                                                               PARTIALSM>,
        &config));
  } else {
    check_cuda_error(cudaFuncSetAttribute(
        device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
                                                               FULLSM>,
        cudaFuncAttributeNonPortableClusterSizeAllowed, false));
    check_cuda_error(cudaOccupancyMaxPotentialClusterSize(
        &cluster_size,
        device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
                                                               FULLSM>,
        &config));
    if (polynomial_size == 2048 && level_count == 1 && glwe_dimension == 1) {
      check_cuda_error(cudaFuncSetAttribute(
          device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params<
              Torus, params, FULLSM>,
          cudaFuncAttributeNonPortableClusterSizeAllowed, false));
      check_cuda_error(cudaOccupancyMaxPotentialClusterSize(
          &cluster_size,
          device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params<
              Torus, params, FULLSM>,
          &config));
    } else {
      check_cuda_error(cudaFuncSetAttribute(
          device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
                                                                 FULLSM>,
          cudaFuncAttributeNonPortableClusterSizeAllowed, false));
      check_cuda_error(cudaOccupancyMaxPotentialClusterSize(
          &cluster_size,
          device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
                                                                 FULLSM>,
          &config));
    }
  }

  return cluster_size >= level_count * (glwe_dimension + 1);
@@ -26,6 +26,15 @@ __device__ void copy_polynomial_in_regs(const T *__restrict__ source, T *dst) {
  }
}

template <typename T, int elems_per_thread, int block_size>
__device__ void copy_polynomial_from_regs(const T *__restrict__ source,
                                          T *dst) {
#pragma unroll
  for (int i = 0; i < elems_per_thread; i++) {
    dst[threadIdx.x + i * block_size] = source[i];
  }
}

/*
 * Receives num_poly concatenated polynomials of type T. For each:
 *
@@ -80,6 +89,36 @@ divide_by_monomial_negacyclic_inplace(T *accumulator,
    }
  }
}

/*
 * Receives num_poly concatenated polynomials of type T. For each:
 *
 * Performs acc = input * X^(-j), i.e. divides by the monomial X^j
 * negacyclically, writing the result into the caller's registers.
 *
 * By default, it works on a single polynomial.
 */
template <typename T, int elems_per_thread, int block_size>
__device__ void divide_by_monomial_negacyclic_2_2_params_inplace(
    T *accumulator, const T *__restrict__ input, uint32_t j) {
  constexpr int degree = block_size * elems_per_thread;
  int tid = threadIdx.x;
  if (j < degree) {
    for (int i = 0; i < elems_per_thread; i++) {
      int x = tid + j - SEL(degree, 0, tid < degree - j);
      accumulator[i] = SEL(-1, 1, tid < degree - j) * input[x];
      tid += block_size;
    }
  } else {
    int32_t jj = j - degree;
    for (int i = 0; i < elems_per_thread; i++) {
      int x = tid + jj - SEL(degree, 0, tid < degree - jj);
      accumulator[i] = SEL(1, -1, tid < degree - jj) * input[x];
      tid += block_size;
    }
  }
}

/*
 * Receives num_poly concatenated polynomials of type T. For each:
 *
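A host-side sketch of the rotation/sign rule implemented above, for the case 0 <= j < degree (illustrative values):

```cpp
// Dividing a negacyclic polynomial by X^j (multiplying by X^-j) is a
// rotation with a sign flip on the wrapped part, since X^N = -1.
#include <cstdio>
#include <vector>

int main() {
  const int N = 8; // small stand-in for params::degree
  std::vector<int> a(N), out(N);
  for (int i = 0; i < N; i++) a[i] = i + 1;

  int j = 3; // rotate by X^-3 (assume 0 <= j < N for this sketch)
  for (int tid = 0; tid < N; tid++) {
    if (tid < N - j)
      out[tid] = a[tid + j];      // no wrap: sign +1
    else
      out[tid] = -a[tid + j - N]; // wrapped around: sign -1
  }

  for (int tid = 0; tid < N; tid++) printf("%d ", out[tid]);
  printf("\n");
  return 0;
}
```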
@@ -143,6 +182,22 @@ __device__ void init_decomposer_state_inplace(T *rotated_acc, int base_log,
  }
}

/*
 * Receives num_poly concatenated polynomials of type T. For each, performs a
 * rounding to increase the accuracy of the PBS. Calculates in place.
 *
 * By default, it works on a single polynomial.
 */
template <typename T, int elems_per_thread, uint32_t block_size,
          uint32_t base_log, int level_count>
__device__ void init_decomposer_state_inplace_2_2_params(T *rotated_acc) {
  for (int i = 0; i < elems_per_thread; i++) {
    T x_acc = rotated_acc[i];
    rotated_acc[i] =
        init_decomposer_state_2_2_params<T, base_log, level_count>(x_acc);
  }
}

/**
 * In case of classical PBS, this method should accumulate the result.
 * In case of multi-bit PBS, it should overwrite.
@@ -173,6 +228,31 @@ __device__ void add_to_torus(double2 *m_values, Torus *result,
  }
}

/**
 * In case of classical PBS, this method should accumulate the result.
 * In case of multi-bit PBS, it should overwrite.
 */
template <typename Torus, class params>
__device__ void add_to_torus_2_2_params(double2 *m_values, Torus *result) {
  int tid = threadIdx.x;
#pragma unroll
  for (int i = 0; i < params::opt / 2; i++) {
    double double_real = m_values[i].x;
    double double_imag = m_values[i].y;

    Torus torus_real = 0;
    typecast_double_round_to_torus<Torus>(double_real, torus_real);

    Torus torus_imag = 0;
    typecast_double_round_to_torus<Torus>(double_imag, torus_imag);

    result[i] = torus_real;
    result[i + params::opt / 2] = torus_imag;

    tid = tid + params::degree / params::opt;
  }
}

/**
 * In case of classical PBS, this method should accumulate the result.
 * In case of multi-bit PBS, it should overwrite.
@@ -61,6 +61,29 @@ __device__ void polynomial_product_accumulate_in_fourier_domain(
  }
}

// Computes result += first * second
// If init_accumulator is set, assumes that result was not initialized and does
// that with the outcome of first * second
// The result is always in registers, and if init_accumulator is true, first
// is also in registers. This is tuned for 2_2 params.
template <class params, typename T, bool init_accumulator>
__device__ void polynomial_product_accumulate_in_fourier_domain_2_2_params(
    T *__restrict__ result, T *__restrict__ first,
    const T *__restrict__ second) {
  int tid = threadIdx.x;
  if constexpr (init_accumulator) {
    for (int i = 0; i < params::opt / 2; i++) {
      result[i] = first[i] * __ldg(&second[tid]);
      tid += (params::degree / params::opt);
    }
  } else {
    for (int i = 0; i < params::opt / 2; i++) {
      result[i] += first[tid] * __ldg(&second[tid]);
      tid += params::degree / params::opt;
    }
  }
}

// Computes result += first * second
// If init_accumulator is set, assumes that result was not initialized and does
// that with the outcome of first * second
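Assuming double2 is treated as a complex number here (as in the FFT code above), the accumulation is a pointwise complex multiply-add over N/2 coefficients. A host-side sketch with std::complex standing in for double2 (illustrative only):

```cpp
// Pointwise Fourier-domain multiply-accumulate: result += first * second.
#include <complex>
#include <cstdio>
#include <vector>

int main() {
  const int half = 4; // stand-in for params::degree / 2
  std::vector<std::complex<double>> result(half, {0.0, 0.0});
  std::vector<std::complex<double>> first(half, {1.0, 2.0});
  std::vector<std::complex<double>> second(half, {0.5, -0.25});

  for (int i = 0; i < half; i++)
    result[i] += first[i] * second[i];

  printf("result[0] = (%f, %f)\n", result[0].real(), result[0].imag());
  return 0;
}
```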
@@ -231,4 +254,24 @@ __device__ void polynomial_accumulate_monic_monomial_mul_on_regs(
  }
}

// Does the same as polynomial_accumulate_monic_monomial_mul() but the result
// is written to registers and the coefficients are precalculated
template <typename T, class params>
__device__ void polynomial_accumulate_monic_monomial_mul_on_regs_precalc(
    T *result, const T *__restrict__ poly, int8_t *coefs,
    uint32_t monomial_degree) {
  // Every thread has a fixed position to track instead of "chasing" the
  // position
#pragma unroll
  for (int i = 0; i < params::opt; i++) {
    int pos =
        (threadIdx.x + i * (params::degree / params::opt) - monomial_degree) &
        (params::degree - 1);

    T element = poly[pos];
    result[i] +=
        coefs[threadIdx.x + i * (params::degree / params::opt)] * element;
  }
}

#endif // CNCRT_POLYNOMIAL_MATH_H
@@ -74,4 +74,11 @@ __device__ inline double2 operator*(double a, double2 b) {
  return {__dmul_rn(b.x, a), __dmul_rn(b.y, a)};
}

__device__ inline double2 shfl_xor_double2(double2 val, int laneMask,
                                           unsigned mask = 0xFFFFFFFF) {
  double re = __shfl_xor_sync(mask, val.x, laneMask);
  double im = __shfl_xor_sync(mask, val.y, laneMask);

  return make_double2(re, im);
}
#endif