mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-11 07:38:08 -05:00
Compare commits
3 Commits
tm/flip
...
go/refacto
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
caf7fdae77 | ||
|
|
abab60a6e8 | ||
|
|
644bac8fd8 |
@@ -305,4 +305,210 @@ __global__ void batch_polynomial_mul(double2 *d_input1, double2 *d_input2,
|
||||
}
|
||||
}
|
||||
|
||||
/*
 * Forward FFT butterfly network operating on coefficients that the caller has
 * already placed in registers.
 *
 * Parameters:
 *  - A: scratch buffer shared by all threads of the block. It is used to
 *       exchange butterfly operands between threads at every level and
 *       receives the final result. Elements A[0 .. 2 * BUTTERFLY_DEPTH *
 *       STRIDE) are written by the block.
 *  - u, v: per-thread operand arrays of params::opt / 2 elements each. They are
 *       consumed as input and overwritten while the transform runs.
 *
 * Twiddle factors are read from `negtwiddles`, which is declared outside this
 * function. No bit reversal is performed here; the original note states that
 * the twiddles are already stored in reversed order.
 *
 * Each thread handles params::opt / 2 butterflies, spaced STRIDE elements
 * apart. The function contains __syncthreads() barriers, so every thread of
 * the block must call it.
 */
template <class params>
__device__ void NSMFFT_direct2(double2 *A, double2 u[params::opt >> 1],
                               double2 v[params::opt >> 1]) {

  constexpr Index BUTTERFLY_DEPTH = params::opt >> 1;
  constexpr Index LOG2_DEGREE = params::log2_degree;
  constexpr Index STRIDE = params::degree / params::opt;

  Index tid;
  double2 w;

  // Level 1: a single twiddle is used for every butterfly. Its real and
  // imaginary parts are equal (sqrt(2) / 2).
#pragma unroll
  for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
    w = v[i] * (double2){0.707106781186547461715008466854,
                         0.707106781186547461715008466854};
    v[i] = u[i] - w;
    u[i] = u[i] + w;
  }

  // Remaining levels: at each level a thread keeps one operand in registers
  // and swaps the other one with a partner thread through A.
  Index twiddle_shift = 1;
  for (Index l = LOG2_DEGREE - 1; l >= 1; --l) {
    // lane_mask == 1 << lane_shift, so dividing by lane_mask is the same as
    // shifting right by lane_shift.
    const Index lane_shift = l - 1;
    const Index lane_mask = 1 << lane_shift;
    const Index thread_mask = (1 << l) - 1;
    twiddle_shift <<= 1;

    // Publish the operand that leaves this thread.
    tid = threadIdx.x;
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
      const bool u_stays_in_register = (tid & thread_mask) < lane_mask;
      A[tid] = (u_stays_in_register) ? v[i] : u[i];
      tid = tid + STRIDE;
    }
    // All writes to A must be visible before partners read them.
    __syncthreads();

    // Fetch the partner's operand and run the butterfly.
    tid = threadIdx.x;
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
      const bool u_stays_in_register = (tid & thread_mask) < lane_mask;
      w = A[tid ^ lane_mask];
      u[i] = (u_stays_in_register) ? u[i] : w;
      v[i] = (u_stays_in_register) ? w : v[i];

      w = negtwiddles[(tid >> lane_shift) + twiddle_shift];
      w *= v[i];

      v[i] = u[i] - w;
      u[i] = u[i] + w;
      tid = tid + STRIDE;
    }
    // A is overwritten by the next level (or by the final store), so every
    // thread must have finished reading it first.
    __syncthreads();
  }

  // Write the register contents back to A, interleaving u and v.
  tid = threadIdx.x;
#pragma unroll
  for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
    A[tid * 2] = u[i];
    A[tid * 2 + 1] = v[i];
    tid = tid + STRIDE;
  }
  __syncthreads();
}
|
||||
|
||||
/*
 * Two-polynomial variant of the register-fed forward FFT: the same butterfly
 * network is applied to two independent operand sets in lock step, so both
 * transforms share the twiddle loads and the block barriers.
 *
 * Parameters:
 *  - A, B: scratch buffers shared by all threads of the block, one per
 *       polynomial. They are used to exchange operands between threads and
 *       receive the final results. The visible code never reads A where it
 *       writes B or vice versa, so the two ranges are presumably expected not
 *       to overlap -- confirm against the caller.
 *  - u, v: per-thread operands of the first polynomial (params::opt / 2
 *       elements each), consumed and overwritten.
 *  - u2, v2: per-thread operands of the second polynomial, same layout.
 *
 * Twiddle factors are read from `negtwiddles`, which is declared outside this
 * function. No bit reversal is performed here; the original note states that
 * the twiddles are already stored in reversed order.
 *
 * The function contains __syncthreads() barriers, so every thread of the
 * block must call it.
 */
template <class params>
__device__ void
NSMFFT_direct2_vec(double2 *A, double2 *B, double2 u[params::opt >> 1],
                   double2 v[params::opt >> 1], double2 u2[params::opt >> 1],
                   double2 v2[params::opt >> 1]) {

  constexpr Index BUTTERFLY_DEPTH = params::opt >> 1;
  constexpr Index LOG2_DEGREE = params::log2_degree;
  constexpr Index STRIDE = params::degree / params::opt;

  Index tid;
  double2 w, w2;

  // Level 1: a single twiddle is used for every butterfly. Its real and
  // imaginary parts are equal (sqrt(2) / 2).
#pragma unroll
  for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
    w = v[i] * (double2){0.707106781186547461715008466854,
                         0.707106781186547461715008466854};
    w2 = v2[i] * (double2){0.707106781186547461715008466854,
                           0.707106781186547461715008466854};

    v[i] = u[i] - w;
    u[i] = u[i] + w;

    v2[i] = u2[i] - w2;
    u2[i] = u2[i] + w2;
  }

  // Remaining levels: at each level a thread keeps one operand of each
  // polynomial in registers and swaps the other with a partner thread.
  Index twiddle_shift = 1;
  for (Index l = LOG2_DEGREE - 1; l >= 1; --l) {
    // lane_mask == 1 << lane_shift, so dividing by lane_mask is the same as
    // shifting right by lane_shift.
    const Index lane_shift = l - 1;
    const Index lane_mask = 1 << lane_shift;
    const Index thread_mask = (1 << l) - 1;
    twiddle_shift <<= 1;

    // Publish the operands that leave this thread.
    tid = threadIdx.x;
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
      const bool u_stays_in_register = (tid & thread_mask) < lane_mask;
      A[tid] = (u_stays_in_register) ? v[i] : u[i];
      B[tid] = (u_stays_in_register) ? v2[i] : u2[i];
      tid = tid + STRIDE;
    }
    // All writes to A and B must be visible before partners read them.
    __syncthreads();

    // Fetch the partner's operands and run both butterflies.
    tid = threadIdx.x;
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
      const bool u_stays_in_register = (tid & thread_mask) < lane_mask;
      w = A[tid ^ lane_mask];
      w2 = B[tid ^ lane_mask];
      u[i] = (u_stays_in_register) ? u[i] : w;
      v[i] = (u_stays_in_register) ? w : v[i];
      u2[i] = (u_stays_in_register) ? u2[i] : w2;
      v2[i] = (u_stays_in_register) ? w2 : v2[i];

      // One twiddle load serves both polynomials.
      w = negtwiddles[(tid >> lane_shift) + twiddle_shift];
      w2 = w * v2[i];
      w *= v[i];

      v[i] = u[i] - w;
      u[i] = u[i] + w;

      v2[i] = u2[i] - w2;
      u2[i] = u2[i] + w2;
      tid = tid + STRIDE;
    }
    // A and B are overwritten by the next level (or by the final store), so
    // every thread must have finished reading them first.
    __syncthreads();
  }

  // Write the register contents back, interleaving u and v for each buffer.
  tid = threadIdx.x;
#pragma unroll
  for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
    A[tid * 2] = u[i];
    A[tid * 2 + 1] = v[i];
    B[tid * 2] = u2[i];
    B[tid * 2 + 1] = v2[i];
    tid = tid + STRIDE;
  }
  __syncthreads();
}
|
||||
|
||||
#endif // GPU_BOOTSTRAP_FFT_CUH
|
||||
|
||||
@@ -48,7 +48,7 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
|
||||
uint32_t level_count, uint32_t lwe_offset, uint32_t lwe_chunk_size,
|
||||
uint32_t keybundle_size_per_input, int8_t *device_mem,
|
||||
uint64_t device_memory_size_per_block) {
|
||||
|
||||
__shared__ uint32_t monomial_degrees[8];
|
||||
extern __shared__ int8_t sharedmem[];
|
||||
int8_t *selected_memory;
|
||||
|
||||
@@ -59,6 +59,189 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
|
||||
blockIdx.z * gridDim.x * gridDim.y;
|
||||
selected_memory = &device_mem[block_index * device_memory_size_per_block];
|
||||
}
|
||||
double2 *fft = (double2 *)selected_memory;
|
||||
double2 *fft2 = fft + polynomial_size / 2;
|
||||
// Ids
|
||||
uint32_t level_id = blockIdx.z;
|
||||
uint32_t glwe_id = blockIdx.y; // / (glwe_dimension + 1);
|
||||
// uint32_t poly_id = 0; // blockIdx.y;// % (glwe_dimension + 1);
|
||||
uint32_t lwe_iteration = (blockIdx.x % lwe_chunk_size + lwe_offset);
|
||||
uint32_t input_idx = blockIdx.x / lwe_chunk_size;
|
||||
|
||||
if (lwe_iteration < (lwe_dimension / grouping_factor)) {
|
||||
|
||||
const Torus *block_lwe_array_in =
|
||||
&lwe_array_in[lwe_input_indexes[input_idx] * (lwe_dimension + 1)];
|
||||
|
||||
double2 *keybundle = keybundle_array +
|
||||
// select the input
|
||||
input_idx * keybundle_size_per_input;
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
// Computes all keybundles
|
||||
uint32_t rev_lwe_iteration =
|
||||
((lwe_dimension / grouping_factor) - lwe_iteration - 1);
|
||||
|
||||
if (threadIdx.x < (1 << grouping_factor)) {
|
||||
const Torus *lwe_array_group =
|
||||
block_lwe_array_in + rev_lwe_iteration * grouping_factor;
|
||||
monomial_degrees[threadIdx.x] = calculates_monomial_degree<Torus, params>(
|
||||
lwe_array_group, threadIdx.x, grouping_factor);
|
||||
}
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// ////////////////////////////////
|
||||
// Keygen guarantees the first term is a constant term of the polynomial, no
|
||||
// polynomial multiplication required
|
||||
const Torus *bsk_slice = get_multi_bit_ith_lwe_gth_group_kth_block(
|
||||
bootstrapping_key, 0, rev_lwe_iteration, glwe_id, level_id,
|
||||
grouping_factor, 2 * polynomial_size, glwe_dimension, level_count);
|
||||
const Torus *bsk_poly_ini = bsk_slice; // + poly_id * params::degree;
|
||||
|
||||
Torus reg_acc[params::opt];
|
||||
Torus reg_acc2[params::opt];
|
||||
|
||||
// copy_polynomial_in_regs<Torus, params::opt, params::degree /
|
||||
// params::opt>(
|
||||
// bsk_poly_ini, reg_acc);
|
||||
|
||||
// copy_polynomial_in_regs<Torus, params::opt, params::degree /
|
||||
// params::opt>(
|
||||
// bsk_poly_ini + params::degree, reg_acc2);
|
||||
|
||||
copy_polynomial_in_regs_vec<Torus, params::opt,
|
||||
params::degree / params::opt>(
|
||||
bsk_poly_ini, reg_acc, bsk_poly_ini + params::degree, reg_acc2);
|
||||
|
||||
int offset =
|
||||
get_start_ith_ggsw_offset(polynomial_size, glwe_dimension, level_count);
|
||||
|
||||
// Precalculate the monomial degrees and store them in shared memory
|
||||
// uint32_t *monomial_degrees = (uint32_t *)selected_memory;
|
||||
|
||||
// if (threadIdx.x < (1 << grouping_factor)) {
|
||||
// const Torus *lwe_array_group =
|
||||
// block_lwe_array_in + rev_lwe_iteration * grouping_factor;
|
||||
// monomial_degrees[threadIdx.x] = calculates_monomial_degree<Torus,
|
||||
// params>(
|
||||
// lwe_array_group, threadIdx.x, grouping_factor);
|
||||
// }
|
||||
// synchronize_threads_in_block();
|
||||
|
||||
// Accumulate the other terms
|
||||
for (int g = 1; g < (1 << grouping_factor); g++) {
|
||||
|
||||
uint32_t monomial_degree = monomial_degrees[g];
|
||||
|
||||
const Torus *bsk_poly = bsk_poly_ini + g * offset;
|
||||
const Torus *bsk_poly2 = bsk_poly_ini + g * offset + params::degree;
|
||||
|
||||
// Multiply by the bsk element
|
||||
polynomial_product_accumulate_by_monomial_nosync_vec<Torus, params>(
|
||||
reg_acc, reg_acc2, bsk_poly, bsk_poly2, monomial_degree);
|
||||
}
|
||||
// synchronize_threads_in_block(); // needed because we are going to reuse
|
||||
// the shared memory for the fft
|
||||
// double2 *fft = (double2 *)selected_memory;
|
||||
// Move from local memory back to shared memory but as complex
|
||||
// int tid = threadIdx.x;
|
||||
// double2 *fft = (double2 *)selected_memory;
|
||||
// #pragma unroll
|
||||
// for (int i = 0; i < params::opt / 2; i++) {
|
||||
// fft[tid] =
|
||||
// make_double2(__ll2double_rn((int64_t)reg_acc[i]) /
|
||||
// (double)std::numeric_limits<Torus>::max(),
|
||||
// __ll2double_rn((int64_t)reg_acc[i + params::opt /
|
||||
// 2]) /
|
||||
// (double)std::numeric_limits<Torus>::max());
|
||||
// tid += params::degree / params::opt;
|
||||
// }
|
||||
double2 u[params::opt >> 2];
|
||||
double2 v[params::opt >> 2];
|
||||
|
||||
double2 u2[params::opt >> 2];
|
||||
double2 v2[params::opt >> 2];
|
||||
for (int i = 0; i < params::opt / 4; i++) {
|
||||
u[i] =
|
||||
make_double2(__ll2double_rn((int64_t)reg_acc[i]) /
|
||||
(double)std::numeric_limits<Torus>::max(),
|
||||
__ll2double_rn((int64_t)reg_acc[i + params::opt / 2]) /
|
||||
(double)std::numeric_limits<Torus>::max());
|
||||
u2[i] =
|
||||
make_double2(__ll2double_rn((int64_t)reg_acc2[i]) /
|
||||
(double)std::numeric_limits<Torus>::max(),
|
||||
__ll2double_rn((int64_t)reg_acc2[i + params::opt / 2]) /
|
||||
(double)std::numeric_limits<Torus>::max());
|
||||
v[i] = make_double2(
|
||||
__ll2double_rn((int64_t)reg_acc[i + params::opt / 4]) /
|
||||
(double)std::numeric_limits<Torus>::max(),
|
||||
__ll2double_rn(
|
||||
(int64_t)reg_acc[i + params::opt / 2 + params::opt / 4]) /
|
||||
(double)std::numeric_limits<Torus>::max());
|
||||
v2[i] = make_double2(
|
||||
__ll2double_rn((int64_t)reg_acc2[i + params::opt / 4]) /
|
||||
(double)std::numeric_limits<Torus>::max(),
|
||||
__ll2double_rn(
|
||||
(int64_t)reg_acc2[i + params::opt / 2 + params::opt / 4]) /
|
||||
(double)std::numeric_limits<Torus>::max());
|
||||
}
|
||||
|
||||
// for (int i = 0; i < params::opt / 4; i++) {
|
||||
// v[i] = make_double2(
|
||||
// __ll2double_rn((int64_t)reg_acc[i + params::opt / 4]) /
|
||||
// (double)std::numeric_limits<Torus>::max(),
|
||||
// __ll2double_rn(
|
||||
// (int64_t)reg_acc[i + params::opt / 2 + params::opt / 4]) /
|
||||
// (double)std::numeric_limits<Torus>::max());
|
||||
// v2[i] = make_double2(
|
||||
// __ll2double_rn((int64_t)reg_acc2[i + params::opt / 4]) /
|
||||
// (double)std::numeric_limits<Torus>::max(),
|
||||
// __ll2double_rn(
|
||||
// (int64_t)reg_acc2[i + params::opt / 2 + params::opt / 4]) /
|
||||
// (double)std::numeric_limits<Torus>::max());
|
||||
|
||||
// }
|
||||
|
||||
NSMFFT_direct2_vec<HalfDegree<params>>(fft, fft2, u, v, u2, v2);
|
||||
|
||||
// lwe iteration
|
||||
auto keybundle_out = get_ith_mask_kth_block(
|
||||
keybundle, blockIdx.x % lwe_chunk_size, glwe_id, level_id,
|
||||
polynomial_size, glwe_dimension, level_count);
|
||||
// auto keybundle_poly = keybundle_out;// + poly_id * params::degree / 2;
|
||||
|
||||
copy_polynomial_vec<double2, params::opt / 2, params::degree / params::opt>(
|
||||
fft, keybundle_out, fft2, keybundle_out + params::degree / 2);
|
||||
|
||||
// copy_polynomial<double2, params::opt / 2, params::degree / params::opt>(
|
||||
// fft, keybundle_out);
|
||||
|
||||
// copy_polynomial<double2, params::opt / 2, params::degree / params::opt>(
|
||||
// fft2, keybundle_out + params::degree / 2);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Torus, class params, sharedMemDegree SMD>
|
||||
__global__ void device_multi_bit_programmable_bootstrap_keybundle_bck(
|
||||
const Torus *__restrict__ lwe_array_in,
|
||||
const Torus *__restrict__ lwe_input_indexes, double2 *keybundle_array,
|
||||
const Torus *__restrict__ bootstrapping_key, uint32_t lwe_dimension,
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor,
|
||||
uint32_t level_count, uint32_t lwe_offset, uint32_t lwe_chunk_size,
|
||||
uint32_t keybundle_size_per_input, int8_t *device_mem,
|
||||
uint64_t device_memory_size_per_block) {
|
||||
__shared__ uint32_t monomial_degrees[8];
|
||||
extern __shared__ int8_t sharedmem[];
|
||||
int8_t *selected_memory;
|
||||
|
||||
if constexpr (SMD == FULLSM) {
|
||||
selected_memory = sharedmem;
|
||||
} else {
|
||||
int block_index = blockIdx.x + blockIdx.y * gridDim.x +
|
||||
blockIdx.z * gridDim.x * gridDim.y;
|
||||
selected_memory = &device_mem[block_index * device_memory_size_per_block];
|
||||
}
|
||||
double2 *fft = (double2 *)selected_memory;
|
||||
|
||||
// Ids
|
||||
uint32_t level_id = blockIdx.z;
|
||||
@@ -98,7 +281,8 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
|
||||
get_start_ith_ggsw_offset(polynomial_size, glwe_dimension, level_count);
|
||||
|
||||
// Precalculate the monomial degrees and store them in shared memory
|
||||
uint32_t *monomial_degrees = (uint32_t *)selected_memory;
|
||||
// uint32_t *monomial_degrees = (uint32_t *)selected_memory;
|
||||
|
||||
if (threadIdx.x < (1 << grouping_factor)) {
|
||||
const Torus *lwe_array_group =
|
||||
block_lwe_array_in + rev_lwe_iteration * grouping_factor;
|
||||
@@ -117,23 +301,43 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
|
||||
polynomial_product_accumulate_by_monomial_nosync<Torus, params>(
|
||||
reg_acc, bsk_poly, monomial_degree);
|
||||
}
|
||||
synchronize_threads_in_block(); // needed because we are going to reuse the
|
||||
// shared memory for the fft
|
||||
|
||||
// synchronize_threads_in_block(); // needed because we are going to reuse
|
||||
// the shared memory for the fft
|
||||
// double2 *fft = (double2 *)selected_memory;
|
||||
// Move from local memory back to shared memory but as complex
|
||||
int tid = threadIdx.x;
|
||||
double2 *fft = (double2 *)selected_memory;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
fft[tid] =
|
||||
// int tid = threadIdx.x;
|
||||
// double2 *fft = (double2 *)selected_memory;
|
||||
// #pragma unroll
|
||||
// for (int i = 0; i < params::opt / 2; i++) {
|
||||
// fft[tid] =
|
||||
// make_double2(__ll2double_rn((int64_t)reg_acc[i]) /
|
||||
// (double)std::numeric_limits<Torus>::max(),
|
||||
// __ll2double_rn((int64_t)reg_acc[i + params::opt /
|
||||
// 2]) /
|
||||
// (double)std::numeric_limits<Torus>::max());
|
||||
// tid += params::degree / params::opt;
|
||||
// }
|
||||
double2 u[params::opt >> 2];
|
||||
double2 v[params::opt >> 2];
|
||||
|
||||
for (int i = 0; i < params::opt / 4; i++) {
|
||||
u[i] =
|
||||
make_double2(__ll2double_rn((int64_t)reg_acc[i]) /
|
||||
(double)std::numeric_limits<Torus>::max(),
|
||||
__ll2double_rn((int64_t)reg_acc[i + params::opt / 2]) /
|
||||
(double)std::numeric_limits<Torus>::max());
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
|
||||
NSMFFT_direct<HalfDegree<params>>(fft);
|
||||
for (int i = 0; i < params::opt / 4; i++) {
|
||||
v[i] = make_double2(
|
||||
__ll2double_rn((int64_t)reg_acc[i + params::opt / 4]) /
|
||||
(double)std::numeric_limits<Torus>::max(),
|
||||
__ll2double_rn(
|
||||
(int64_t)reg_acc[i + params::opt / 2 + params::opt / 4]) /
|
||||
(double)std::numeric_limits<Torus>::max());
|
||||
}
|
||||
|
||||
NSMFFT_direct2<HalfDegree<params>>(fft, u, v);
|
||||
|
||||
// lwe iteration
|
||||
auto keybundle_out = get_ith_mask_kth_block(
|
||||
@@ -363,7 +567,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
|
||||
/*
 * Returns the number of bytes of per-block scratch memory the keybundle
 * kernel needs when its working buffers are placed in shared memory.
 *
 * The kernel carves two FFT buffers out of this region: the second one starts
 * polynomial_size / 2 double2 elements after the first, and each holds
 * polynomial_size / 2 double2 elements. The total is therefore
 * polynomial_size double2 elements.
 */
template <typename Torus>
uint64_t get_buffer_size_full_sm_multibit_programmable_bootstrap_keybundle(
    uint32_t polynomial_size) {
  return sizeof(double2) * polynomial_size; // two half-size accumulators
}
|
||||
template <typename Torus>
|
||||
uint64_t get_buffer_size_full_sm_multibit_programmable_bootstrap_step_one(
|
||||
@@ -513,8 +717,12 @@ __host__ void execute_compute_keybundle(
|
||||
auto keybundle_fft = buffer->keybundle_fft;
|
||||
|
||||
// Compute a keybundle
|
||||
dim3 grid_keybundle(num_samples * chunk_size,
|
||||
(glwe_dimension + 1) * (glwe_dimension + 1), level_count);
|
||||
// dim3 grid_keybundle(num_samples * chunk_size,
|
||||
// (glwe_dimension + 1) * (glwe_dimension + 1),
|
||||
// level_count);
|
||||
dim3 grid_keybundle(num_samples * chunk_size, (glwe_dimension + 1),
|
||||
level_count);
|
||||
|
||||
dim3 thds(polynomial_size / params::opt, 1, 1);
|
||||
|
||||
if (max_shared_memory < full_sm_keybundle)
|
||||
|
||||
@@ -17,6 +17,18 @@ __device__ void copy_polynomial(const T *__restrict__ source, T *dst) {
|
||||
tid = tid + block_size;
|
||||
}
|
||||
}
|
||||
/*
 * Copies two arrays at once: source -> dst and source2 -> dst2.
 *
 * Each calling thread copies elems_per_thread elements per array, starting at
 * index threadIdx.x and advancing by block_size between elements.
 */
template <typename T, int elems_per_thread, int block_size>
__device__ void copy_polynomial_vec(const T *__restrict__ source, T *dst,
                                    const T *__restrict__ source2, T *dst2) {
  const int first = threadIdx.x;
#pragma unroll
  for (int k = 0; k < elems_per_thread; ++k) {
    const int idx = first + k * block_size;
    dst[idx] = source[idx];
    dst2[idx] = source2[idx];
  }
}
|
||||
|
||||
template <typename T, int elems_per_thread, int block_size>
|
||||
__device__ void copy_polynomial_in_regs(const T *__restrict__ source, T *dst) {
|
||||
#pragma unroll
|
||||
@@ -25,6 +37,17 @@ __device__ void copy_polynomial_in_regs(const T *__restrict__ source, T *dst) {
|
||||
}
|
||||
}
|
||||
|
||||
/*
 * Gathers this thread's elements of two arrays into two compact per-thread
 * arrays.
 *
 * For k in [0, elems_per_thread): dst[k] and dst2[k] receive the elements of
 * source and source2 at index threadIdx.x + k * block_size.
 */
template <typename T, int elems_per_thread, int block_size>
__device__ void
copy_polynomial_in_regs_vec(const T *__restrict__ source, T *dst,
                            const T *__restrict__ source2, T *dst2) {
  unsigned int src_idx = threadIdx.x;
#pragma unroll
  for (int k = 0; k < elems_per_thread; ++k) {
    dst[k] = source[src_idx];
    dst2[k] = source2[src_idx];
    src_idx += block_size;
  }
}
|
||||
|
||||
/*
|
||||
* Receives num_poly concatenated polynomials of type T. For each:
|
||||
*
|
||||
|
||||
@@ -130,4 +130,40 @@ __device__ void polynomial_product_accumulate_by_monomial_nosync(
|
||||
}
|
||||
}
|
||||
|
||||
/*
 * Two-polynomial variant of the monomial product-accumulate: adds poly and
 * poly2, each rotated by monomial_degree positions modulo params::degree and
 * with per-element sign adjustments, into result and result2.
 *
 * Parameters:
 *  - result, result2: per-thread accumulators of params::opt elements each;
 *       slot i corresponds to coefficient
 *       threadIdx.x + i * (params::degree / params::opt).
 *  - poly, poly2: input polynomials, read at the rotated position.
 *  - monomial_degree: expected in [0, 2 * params::degree).
 *
 * Sign selection goes through the SEL macro, which is defined outside this
 * function. NOTE(review): SEL(a, b, flag) presumably yields b when flag is
 * non-zero and a otherwise -- confirm against its definition.
 *
 * As the name says, no barrier is issued here.
 */
template <typename T, class params>
__device__ void polynomial_product_accumulate_by_monomial_nosync_vec(
    T *result, T *result2, const T *__restrict__ poly,
    const T *__restrict__ poly2, uint32_t monomial_degree) {
  // Split the degree into the number of whole turns around the ring and the
  // leftover rotation.
  const int full_cycles_count = monomial_degree / params::degree;
  const int remainder_degrees = monomial_degree % params::degree;
  // Same flag for every coefficient: parity of the number of full cycles.
  const int odd_cycle_count = full_cycles_count % 2;

  // Each thread owns fixed output positions and looks up the matching source
  // coefficient.
#pragma unroll
  for (int i = 0; i < params::opt; i++) {
    const auto coeff_idx = threadIdx.x + i * (params::degree / params::opt);
    const int pos = (coeff_idx - monomial_degree) & (params::degree - 1);

    const T element = poly[pos];
    const T element2 = poly2[pos];

    T x = SEL(element, -element, odd_cycle_count);
    T x2 = SEL(element2, -element2, odd_cycle_count);

    const bool condition = coeff_idx >= remainder_degrees;
    x = SEL(-x, x, condition);
    x2 = SEL(-x2, x2, condition);

    result[i] += x;
    result2[i] += x2;
  }
}
|
||||
|
||||
#endif // CNCRT_POLYNOMIAL_MATH_H
|
||||
|
||||
Reference in New Issue
Block a user