Compare commits

...

3 Commits

Author SHA1 Message Date
Guillermo Oyarzun
caf7fdae77 try new vec functions 2024-11-29 09:52:22 +01:00
Guillermo Oyarzun
abab60a6e8 add doublekeybundle 2024-11-26 17:32:05 +01:00
Guillermo Oyarzun
644bac8fd8 remove some syncs 2024-11-25 16:45:24 +01:00
4 changed files with 488 additions and 15 deletions

View File

@@ -305,4 +305,210 @@ __global__ void batch_polynomial_mul(double2 *d_input1, double2 *d_input2,
}
}
/*
 * Forward FFT butterfly network that starts from operands the caller has
 * already placed in per-thread arrays (this function never loads from A).
 *
 * No bit reversal is performed here, since the twiddles in `negtwiddles` are
 * already stored in reversed order. Each thread is in charge of "opt/2" pairs
 * of coefficients, which is why every loop walks N/2 coefficients in strides
 * of N/opt. The inner loops are fully unrolled with #pragma unroll.
 *
 * Template parameter:
 *  - params: provides the compile-time constants degree, log2_degree and opt.
 *
 * Parameters:
 *  - A: scratch buffer used to exchange operands between threads at every
 *       level. On return it holds the result interleaved: for the i-th pair
 *       of a thread, with t = threadIdx.x + i * (degree / opt),
 *       A[2 * t] = u[i] and A[2 * t + 1] = v[i].
 *  - u, v: the opt/2 butterfly operand pairs of this thread; updated in
 *       place and still holding the final values on return.
 *
 * Preconditions:
 *  - The function calls __syncthreads(), so every thread of the block must
 *    call it.
 *  - The first write to A is not preceded by a barrier: the caller must make
 *    sure no thread is still reading A when this function is entered.
 *  - NOTE(review): the indexing assumes the block runs degree / opt threads
 *    and that A has room for `degree` double2 values - confirm at call site.
 */
template <class params>
__device__ void NSMFFT_direct2(double2 *A, double2 u[params::opt >> 1],
                               double2 v[params::opt >> 1]) {
  constexpr Index BUTTERFLY_DEPTH = params::opt >> 1;
  constexpr Index LOG2_DEGREE = params::log2_degree;
  constexpr Index STRIDE = params::degree / params::opt;

  Index tid = threadIdx.x;
  double2 w;

  // level 1
  // There is a single twiddle at this level and its real and imaginary parts
  // are equal (approximately sqrt(2) / 2), so it is built once, outside the
  // loop, with standard aggregate initialisation.
  const double2 level1_twiddle = {0.707106781186547461715008466854,
                                  0.707106781186547461715008466854};
#pragma unroll
  for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
    w = v[i] * level1_twiddle;
    v[i] = u[i] - w;
    u[i] = u[i] + w;
  }

  Index twiddle_shift = 1;
  for (Index l = LOG2_DEGREE - 1; l >= 1; --l) {
    // Distance (as an XOR mask) between a coefficient and its partner.
    Index lane_mask = 1 << (l - 1);
    // Selects the position of a coefficient inside its group of 2^l.
    Index thread_mask = (1 << l) - 1;
    twiddle_shift <<= 1;

    // Publish the operand that has to move to the partner: the lower half of
    // each group keeps u and hands over v, the upper half does the opposite.
    tid = threadIdx.x;
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
      Index rank = tid & thread_mask;
      bool u_stays_in_register = rank < lane_mask;
      A[tid] = (u_stays_in_register) ? v[i] : u[i];
      tid = tid + STRIDE;
    }
    // Every published operand must be visible before any partner reads it.
    __syncthreads();

    // Fetch the partner's operand, then run the butterfly with the twiddle
    // of this level.
    tid = threadIdx.x;
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
      Index rank = tid & thread_mask;
      bool u_stays_in_register = rank < lane_mask;
      w = A[tid ^ lane_mask];
      u[i] = (u_stays_in_register) ? u[i] : w;
      v[i] = (u_stays_in_register) ? w : v[i];
      w = negtwiddles[tid / lane_mask + twiddle_shift];
      w *= v[i];
      v[i] = u[i] - w;
      u[i] = u[i] + w;
      tid = tid + STRIDE;
    }
    // All reads of A must finish before the next level (or the final store)
    // overwrites it.
    __syncthreads();
  }

  // Store the per-thread results in A, interleaving u and v.
  tid = threadIdx.x;
#pragma unroll
  for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
    A[tid * 2] = u[i];
    A[tid * 2 + 1] = v[i];
    tid = tid + STRIDE;
  }
  // Make the stored result visible to the whole block before returning.
  __syncthreads();
}
/*
 * Two-polynomial version of the register-based forward FFT: the same
 * butterfly network is applied to (u, v) through scratch buffer A and to
 * (u2, v2) through scratch buffer B, sharing the index arithmetic, the
 * twiddle loads and the barriers between both transforms.
 *
 * No bit reversal is performed here, since the twiddles in `negtwiddles` are
 * already stored in reversed order. Each thread is in charge of "opt/2" pairs
 * of coefficients per polynomial, which is why every loop walks N/2
 * coefficients in strides of N/opt. The inner loops are fully unrolled with
 * #pragma unroll.
 *
 * Template parameter:
 *  - params: provides the compile-time constants degree, log2_degree and opt.
 *
 * Parameters:
 *  - A, B: scratch buffers used to exchange operands between threads. On
 *       return they hold the results interleaved: for the i-th pair of a
 *       thread, with t = threadIdx.x + i * (degree / opt),
 *       A[2 * t] = u[i], A[2 * t + 1] = v[i], B[2 * t] = u2[i] and
 *       B[2 * t + 1] = v2[i].
 *  - u, v: the opt/2 operand pairs of the first polynomial, updated in place.
 *  - u2, v2: the opt/2 operand pairs of the second polynomial, updated in
 *       place.
 *
 * Preconditions:
 *  - The function calls __syncthreads(), so every thread of the block must
 *    call it.
 *  - The first writes to A and B are not preceded by a barrier: the caller
 *    must make sure no thread is still reading them on entry.
 *  - NOTE(review): the indexing assumes the block runs degree / opt threads,
 *    that A and B each have room for `degree` double2 values and that they
 *    do not overlap - confirm at call site.
 */
template <class params>
__device__ void
NSMFFT_direct2_vec(double2 *A, double2 *B, double2 u[params::opt >> 1],
                   double2 v[params::opt >> 1], double2 u2[params::opt >> 1],
                   double2 v2[params::opt >> 1]) {
  constexpr Index BUTTERFLY_DEPTH = params::opt >> 1;
  constexpr Index LOG2_DEGREE = params::log2_degree;
  constexpr Index STRIDE = params::degree / params::opt;

  Index tid = threadIdx.x;
  double2 w, w2;

  // level 1
  // There is a single twiddle at this level and its real and imaginary parts
  // are equal (approximately sqrt(2) / 2), so it is built once, outside the
  // loop, with standard aggregate initialisation.
  const double2 level1_twiddle = {0.707106781186547461715008466854,
                                  0.707106781186547461715008466854};
#pragma unroll
  for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
    w = v[i] * level1_twiddle;
    w2 = v2[i] * level1_twiddle;
    v[i] = u[i] - w;
    u[i] = u[i] + w;
    v2[i] = u2[i] - w2;
    u2[i] = u2[i] + w2;
  }

  Index twiddle_shift = 1;
  for (Index l = LOG2_DEGREE - 1; l >= 1; --l) {
    // Distance (as an XOR mask) between a coefficient and its partner.
    Index lane_mask = 1 << (l - 1);
    // Selects the position of a coefficient inside its group of 2^l.
    Index thread_mask = (1 << l) - 1;
    twiddle_shift <<= 1;

    // Publish the operands that have to move to the partner: the lower half
    // of each group keeps u and hands over v, the upper half does the
    // opposite. Both polynomials follow the same pattern.
    tid = threadIdx.x;
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
      Index rank = tid & thread_mask;
      bool u_stays_in_register = rank < lane_mask;
      A[tid] = (u_stays_in_register) ? v[i] : u[i];
      B[tid] = (u_stays_in_register) ? v2[i] : u2[i];
      tid = tid + STRIDE;
    }
    // Every published operand must be visible before any partner reads it.
    __syncthreads();

    // Fetch the partner's operands, then run both butterflies with the same
    // twiddle.
    tid = threadIdx.x;
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
      Index rank = tid & thread_mask;
      bool u_stays_in_register = rank < lane_mask;
      w = A[tid ^ lane_mask];
      w2 = B[tid ^ lane_mask];
      u[i] = (u_stays_in_register) ? u[i] : w;
      v[i] = (u_stays_in_register) ? w : v[i];
      u2[i] = (u_stays_in_register) ? u2[i] : w2;
      v2[i] = (u_stays_in_register) ? w2 : v2[i];
      w = negtwiddles[tid / lane_mask + twiddle_shift];
      // The twiddle is consumed by the second polynomial first because the
      // next statement overwrites w with the product for the first one.
      w2 = w * v2[i];
      w *= v[i];
      v[i] = u[i] - w;
      u[i] = u[i] + w;
      v2[i] = u2[i] - w2;
      u2[i] = u2[i] + w2;
      tid = tid + STRIDE;
    }
    // All reads of A and B must finish before the next level (or the final
    // store) overwrites them.
    __syncthreads();
  }

  // Store the per-thread results in A and B, interleaving u and v.
  tid = threadIdx.x;
#pragma unroll
  for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
    A[tid * 2] = u[i];
    A[tid * 2 + 1] = v[i];
    B[tid * 2] = u2[i];
    B[tid * 2 + 1] = v2[i];
    tid = tid + STRIDE;
  }
  // Make the stored results visible to the whole block before returning.
  __syncthreads();
}
#endif // GPU_BOOTSTRAP_FFT_CUH

View File

@@ -48,7 +48,7 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
uint32_t level_count, uint32_t lwe_offset, uint32_t lwe_chunk_size,
uint32_t keybundle_size_per_input, int8_t *device_mem,
uint64_t device_memory_size_per_block) {
__shared__ uint32_t monomial_degrees[8];
extern __shared__ int8_t sharedmem[];
int8_t *selected_memory;
@@ -59,6 +59,189 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
blockIdx.z * gridDim.x * gridDim.y;
selected_memory = &device_mem[block_index * device_memory_size_per_block];
}
double2 *fft = (double2 *)selected_memory;
double2 *fft2 = fft + polynomial_size / 2;
// Ids
uint32_t level_id = blockIdx.z;
uint32_t glwe_id = blockIdx.y; // / (glwe_dimension + 1);
// uint32_t poly_id = 0; // blockIdx.y;// % (glwe_dimension + 1);
uint32_t lwe_iteration = (blockIdx.x % lwe_chunk_size + lwe_offset);
uint32_t input_idx = blockIdx.x / lwe_chunk_size;
if (lwe_iteration < (lwe_dimension / grouping_factor)) {
const Torus *block_lwe_array_in =
&lwe_array_in[lwe_input_indexes[input_idx] * (lwe_dimension + 1)];
double2 *keybundle = keybundle_array +
// select the input
input_idx * keybundle_size_per_input;
////////////////////////////////////////////////////////////
// Computes all keybundles
uint32_t rev_lwe_iteration =
((lwe_dimension / grouping_factor) - lwe_iteration - 1);
if (threadIdx.x < (1 << grouping_factor)) {
const Torus *lwe_array_group =
block_lwe_array_in + rev_lwe_iteration * grouping_factor;
monomial_degrees[threadIdx.x] = calculates_monomial_degree<Torus, params>(
lwe_array_group, threadIdx.x, grouping_factor);
}
synchronize_threads_in_block();
// ////////////////////////////////
// Keygen guarantees the first term is a constant term of the polynomial, no
// polynomial multiplication required
const Torus *bsk_slice = get_multi_bit_ith_lwe_gth_group_kth_block(
bootstrapping_key, 0, rev_lwe_iteration, glwe_id, level_id,
grouping_factor, 2 * polynomial_size, glwe_dimension, level_count);
const Torus *bsk_poly_ini = bsk_slice; // + poly_id * params::degree;
Torus reg_acc[params::opt];
Torus reg_acc2[params::opt];
// copy_polynomial_in_regs<Torus, params::opt, params::degree /
// params::opt>(
// bsk_poly_ini, reg_acc);
// copy_polynomial_in_regs<Torus, params::opt, params::degree /
// params::opt>(
// bsk_poly_ini + params::degree, reg_acc2);
copy_polynomial_in_regs_vec<Torus, params::opt,
params::degree / params::opt>(
bsk_poly_ini, reg_acc, bsk_poly_ini + params::degree, reg_acc2);
int offset =
get_start_ith_ggsw_offset(polynomial_size, glwe_dimension, level_count);
// Precalculate the monomial degrees and store them in shared memory
// uint32_t *monomial_degrees = (uint32_t *)selected_memory;
// if (threadIdx.x < (1 << grouping_factor)) {
// const Torus *lwe_array_group =
// block_lwe_array_in + rev_lwe_iteration * grouping_factor;
// monomial_degrees[threadIdx.x] = calculates_monomial_degree<Torus,
// params>(
// lwe_array_group, threadIdx.x, grouping_factor);
// }
// synchronize_threads_in_block();
// Accumulate the other terms
for (int g = 1; g < (1 << grouping_factor); g++) {
uint32_t monomial_degree = monomial_degrees[g];
const Torus *bsk_poly = bsk_poly_ini + g * offset;
const Torus *bsk_poly2 = bsk_poly_ini + g * offset + params::degree;
// Multiply by the bsk element
polynomial_product_accumulate_by_monomial_nosync_vec<Torus, params>(
reg_acc, reg_acc2, bsk_poly, bsk_poly2, monomial_degree);
}
// synchronize_threads_in_block(); // needed because we are going to reuse
// the shared memory for the fft
// double2 *fft = (double2 *)selected_memory;
// Move from local memory back to shared memory but as complex
// int tid = threadIdx.x;
// double2 *fft = (double2 *)selected_memory;
// #pragma unroll
// for (int i = 0; i < params::opt / 2; i++) {
// fft[tid] =
// make_double2(__ll2double_rn((int64_t)reg_acc[i]) /
// (double)std::numeric_limits<Torus>::max(),
// __ll2double_rn((int64_t)reg_acc[i + params::opt /
// 2]) /
// (double)std::numeric_limits<Torus>::max());
// tid += params::degree / params::opt;
// }
double2 u[params::opt >> 2];
double2 v[params::opt >> 2];
double2 u2[params::opt >> 2];
double2 v2[params::opt >> 2];
for (int i = 0; i < params::opt / 4; i++) {
u[i] =
make_double2(__ll2double_rn((int64_t)reg_acc[i]) /
(double)std::numeric_limits<Torus>::max(),
__ll2double_rn((int64_t)reg_acc[i + params::opt / 2]) /
(double)std::numeric_limits<Torus>::max());
u2[i] =
make_double2(__ll2double_rn((int64_t)reg_acc2[i]) /
(double)std::numeric_limits<Torus>::max(),
__ll2double_rn((int64_t)reg_acc2[i + params::opt / 2]) /
(double)std::numeric_limits<Torus>::max());
v[i] = make_double2(
__ll2double_rn((int64_t)reg_acc[i + params::opt / 4]) /
(double)std::numeric_limits<Torus>::max(),
__ll2double_rn(
(int64_t)reg_acc[i + params::opt / 2 + params::opt / 4]) /
(double)std::numeric_limits<Torus>::max());
v2[i] = make_double2(
__ll2double_rn((int64_t)reg_acc2[i + params::opt / 4]) /
(double)std::numeric_limits<Torus>::max(),
__ll2double_rn(
(int64_t)reg_acc2[i + params::opt / 2 + params::opt / 4]) /
(double)std::numeric_limits<Torus>::max());
}
// for (int i = 0; i < params::opt / 4; i++) {
// v[i] = make_double2(
// __ll2double_rn((int64_t)reg_acc[i + params::opt / 4]) /
// (double)std::numeric_limits<Torus>::max(),
// __ll2double_rn(
// (int64_t)reg_acc[i + params::opt / 2 + params::opt / 4]) /
// (double)std::numeric_limits<Torus>::max());
// v2[i] = make_double2(
// __ll2double_rn((int64_t)reg_acc2[i + params::opt / 4]) /
// (double)std::numeric_limits<Torus>::max(),
// __ll2double_rn(
// (int64_t)reg_acc2[i + params::opt / 2 + params::opt / 4]) /
// (double)std::numeric_limits<Torus>::max());
// }
NSMFFT_direct2_vec<HalfDegree<params>>(fft, fft2, u, v, u2, v2);
// lwe iteration
auto keybundle_out = get_ith_mask_kth_block(
keybundle, blockIdx.x % lwe_chunk_size, glwe_id, level_id,
polynomial_size, glwe_dimension, level_count);
// auto keybundle_poly = keybundle_out;// + poly_id * params::degree / 2;
copy_polynomial_vec<double2, params::opt / 2, params::degree / params::opt>(
fft, keybundle_out, fft2, keybundle_out + params::degree / 2);
// copy_polynomial<double2, params::opt / 2, params::degree / params::opt>(
// fft, keybundle_out);
// copy_polynomial<double2, params::opt / 2, params::degree / params::opt>(
// fft2, keybundle_out + params::degree / 2);
}
}
template <typename Torus, class params, sharedMemDegree SMD>
__global__ void device_multi_bit_programmable_bootstrap_keybundle_bck(
const Torus *__restrict__ lwe_array_in,
const Torus *__restrict__ lwe_input_indexes, double2 *keybundle_array,
const Torus *__restrict__ bootstrapping_key, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor,
uint32_t level_count, uint32_t lwe_offset, uint32_t lwe_chunk_size,
uint32_t keybundle_size_per_input, int8_t *device_mem,
uint64_t device_memory_size_per_block) {
__shared__ uint32_t monomial_degrees[8];
extern __shared__ int8_t sharedmem[];
int8_t *selected_memory;
if constexpr (SMD == FULLSM) {
selected_memory = sharedmem;
} else {
int block_index = blockIdx.x + blockIdx.y * gridDim.x +
blockIdx.z * gridDim.x * gridDim.y;
selected_memory = &device_mem[block_index * device_memory_size_per_block];
}
double2 *fft = (double2 *)selected_memory;
// Ids
uint32_t level_id = blockIdx.z;
@@ -98,7 +281,8 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
get_start_ith_ggsw_offset(polynomial_size, glwe_dimension, level_count);
// Precalculate the monomial degrees and store them in shared memory
uint32_t *monomial_degrees = (uint32_t *)selected_memory;
// uint32_t *monomial_degrees = (uint32_t *)selected_memory;
if (threadIdx.x < (1 << grouping_factor)) {
const Torus *lwe_array_group =
block_lwe_array_in + rev_lwe_iteration * grouping_factor;
@@ -117,23 +301,43 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
polynomial_product_accumulate_by_monomial_nosync<Torus, params>(
reg_acc, bsk_poly, monomial_degree);
}
synchronize_threads_in_block(); // needed because we are going to reuse the
// shared memory for the fft
// synchronize_threads_in_block(); // needed because we are going to reuse
// the shared memory for the fft
// double2 *fft = (double2 *)selected_memory;
// Move from local memory back to shared memory but as complex
int tid = threadIdx.x;
double2 *fft = (double2 *)selected_memory;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
fft[tid] =
// int tid = threadIdx.x;
// double2 *fft = (double2 *)selected_memory;
// #pragma unroll
// for (int i = 0; i < params::opt / 2; i++) {
// fft[tid] =
// make_double2(__ll2double_rn((int64_t)reg_acc[i]) /
// (double)std::numeric_limits<Torus>::max(),
// __ll2double_rn((int64_t)reg_acc[i + params::opt /
// 2]) /
// (double)std::numeric_limits<Torus>::max());
// tid += params::degree / params::opt;
// }
double2 u[params::opt >> 2];
double2 v[params::opt >> 2];
for (int i = 0; i < params::opt / 4; i++) {
u[i] =
make_double2(__ll2double_rn((int64_t)reg_acc[i]) /
(double)std::numeric_limits<Torus>::max(),
__ll2double_rn((int64_t)reg_acc[i + params::opt / 2]) /
(double)std::numeric_limits<Torus>::max());
tid += params::degree / params::opt;
}
NSMFFT_direct<HalfDegree<params>>(fft);
for (int i = 0; i < params::opt / 4; i++) {
v[i] = make_double2(
__ll2double_rn((int64_t)reg_acc[i + params::opt / 4]) /
(double)std::numeric_limits<Torus>::max(),
__ll2double_rn(
(int64_t)reg_acc[i + params::opt / 2 + params::opt / 4]) /
(double)std::numeric_limits<Torus>::max());
}
NSMFFT_direct2<HalfDegree<params>>(fft, u, v);
// lwe iteration
auto keybundle_out = get_ith_mask_kth_block(
@@ -363,7 +567,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
/*
 * Size in bytes of the scratch buffer one block of the keybundle kernel
 * needs when it runs entirely from shared memory.
 *
 * The keybundle kernel splits this buffer into two back-to-back FFT areas
 * (fft and fft2 = fft + polynomial_size / 2), so it needs polynomial_size
 * double2 values in total rather than polynomial_size / 2.
 *
 *  - polynomial_size: number of coefficients of the polynomial.
 *
 * Returns sizeof(double2) * polynomial_size.
 */
template <typename Torus>
uint64_t get_buffer_size_full_sm_multibit_programmable_bootstrap_keybundle(
    uint32_t polynomial_size) {
  return sizeof(double2) * polynomial_size; // two half-size accumulators
}
template <typename Torus>
uint64_t get_buffer_size_full_sm_multibit_programmable_bootstrap_step_one(
@@ -513,8 +717,12 @@ __host__ void execute_compute_keybundle(
auto keybundle_fft = buffer->keybundle_fft;
// Compute a keybundle
dim3 grid_keybundle(num_samples * chunk_size,
(glwe_dimension + 1) * (glwe_dimension + 1), level_count);
// dim3 grid_keybundle(num_samples * chunk_size,
// (glwe_dimension + 1) * (glwe_dimension + 1),
// level_count);
dim3 grid_keybundle(num_samples * chunk_size, (glwe_dimension + 1),
level_count);
dim3 thds(polynomial_size / params::opt, 1, 1);
if (max_shared_memory < full_sm_keybundle)

View File

@@ -17,6 +17,18 @@ __device__ void copy_polynomial(const T *__restrict__ source, T *dst) {
tid = tid + block_size;
}
}
/*
 * Copies two arrays in a single pass: source -> dst and source2 -> dst2.
 *
 * Each thread copies elems_per_thread elements of each array, starting at
 * index threadIdx.x and advancing by block_size between elements. The same
 * index is used for reading and writing, and for both array pairs.
 *
 * No synchronization is performed here.
 */
template <typename T, int elems_per_thread, int block_size>
__device__ void copy_polynomial_vec(const T *__restrict__ source, T *dst,
                                    const T *__restrict__ source2, T *dst2) {
  const int first = threadIdx.x;
#pragma unroll
  for (int k = 0; k < elems_per_thread; ++k) {
    // k-th element owned by this thread
    const int idx = first + k * block_size;
    dst[idx] = source[idx];
    dst2[idx] = source2[idx];
  }
}
template <typename T, int elems_per_thread, int block_size>
__device__ void copy_polynomial_in_regs(const T *__restrict__ source, T *dst) {
#pragma unroll
@@ -25,6 +37,17 @@ __device__ void copy_polynomial_in_regs(const T *__restrict__ source, T *dst) {
}
}
/*
 * Gathers the elements owned by the calling thread from two strided arrays
 * into two densely indexed per-thread arrays.
 *
 * For every slot in [0, elems_per_thread), dst[slot] and dst2[slot] receive
 * the element at index threadIdx.x + slot * block_size of source and source2
 * respectively.
 *
 * NOTE(review): going by the function name, dst and dst2 are meant to be
 * per-thread arrays that the compiler can keep in registers - confirm at the
 * call site.
 */
template <typename T, int elems_per_thread, int block_size>
__device__ void
copy_polynomial_in_regs_vec(const T *__restrict__ source, T *dst,
                            const T *__restrict__ source2, T *dst2) {
  // Running source index; the same position is read from both inputs.
  unsigned int src_idx = threadIdx.x;
#pragma unroll
  for (int slot = 0; slot < elems_per_thread; ++slot) {
    dst[slot] = source[src_idx];
    dst2[slot] = source2[src_idx];
    src_idx += block_size;
  }
}
/*
* Receives num_poly concatenated polynomials of type T. For each:
*

View File

@@ -130,4 +130,40 @@ __device__ void polynomial_product_accumulate_by_monomial_nosync(
}
}
/*
 * Two-polynomial monomial product-accumulate: adds poly and poly2, each
 * rotated by monomial_degree positions, into result and result2.
 *
 * Every thread handles params::opt fixed coefficient positions,
 * threadIdx.x + k * (degree / opt); result[k] and result2[k] are the
 * per-thread accumulators for the k-th of those positions. The source
 * coefficient is read at (position - monomial_degree) reduced with the mask
 * (degree - 1), which requires params::degree to be a power of two.
 *
 * The sign of each term is picked through the SEL macro in two steps: first
 * on the parity of the number of full rotations (monomial_degree / degree),
 * then on whether the position is at or beyond the remaining shift
 * (monomial_degree % degree).
 * NOTE(review): presumably SEL(a, b, c) yields b when c is true and a
 * otherwise - confirm against its definition.
 *
 * As the name indicates, no synchronization is performed here.
 */
template <typename T, class params>
__device__ void polynomial_product_accumulate_by_monomial_nosync_vec(
    T *result, T *result2, const T *__restrict__ poly,
    const T *__restrict__ poly2, uint32_t monomial_degree) {
  // monomial_degree \in [0, 2 * params::degree)
  const int wraps = monomial_degree / params::degree;
  const int shift = monomial_degree % params::degree;

  // Every thread tracks fixed positions instead of "chasing" the rotated
  // coefficients.
#pragma unroll
  for (int k = 0; k < params::opt; k++) {
    // Position owned by this thread for slot k.
    const auto coeff = threadIdx.x + k * (params::degree / params::opt);
    // Where that position comes from before the rotation.
    const int src = (coeff - monomial_degree) & (params::degree - 1);
    const bool at_or_past_shift = coeff >= shift;

    const T raw = poly[src];
    const T raw2 = poly2[src];

    // First sign selection: parity of the full rotations.
    T term = SEL(raw, -raw, wraps % 2);
    T term2 = SEL(raw2, -raw2, wraps % 2);

    // Second sign selection: position relative to the remaining shift.
    term = SEL(-term, term, at_or_past_shift);
    term2 = SEL(-term2, term2, at_or_past_shift);

    result[k] += term;
    result2[k] += term2;
  }
}
#endif // CNCRT_POLYNOMIAL_MATH_H