mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-07 22:04:10 -05:00
feat(gpu): perform fft128 in warps when possible
This commit is contained in:
@@ -259,6 +259,16 @@ struct f128x2 {
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
// Warp-level XOR shuffle of a complex f128 value (f128x2).
//
// Each of the four double components (re.hi, re.lo, im.hi, im.lo) is passed
// through __shfl_xor_sync separately with the same `mask` and `laneMask`, and
// the results are reassembled into a new f128x2.
//
//   val      - value contributed by the calling lane
//   laneMask - XOR pattern forwarded to __shfl_xor_sync to select the source
//              lane
//   mask     - participation mask forwarded to __shfl_xor_sync; defaults to
//              0xFFFFFFFF (all 32 lanes)
__device__ inline f128x2 shfl_xor_f128x2(f128x2 val, int laneMask,
                                         unsigned mask = 0xFFFFFFFF) {
  // real part: high word, then low word
  const double peer_re_hi = __shfl_xor_sync(mask, val.re.hi, laneMask);
  const double peer_re_lo = __shfl_xor_sync(mask, val.re.lo, laneMask);

  // imaginary part: high word, then low word
  const double peer_im_hi = __shfl_xor_sync(mask, val.im.hi, laneMask);
  const double peer_im_lo = __shfl_xor_sync(mask, val.im.lo, laneMask);

  return f128x2(f128(peer_re_hi, peer_re_lo), f128(peer_im_hi, peer_im_lo));
}
|
||||
|
||||
__host__ __device__ inline uint64_t double_to_bits(double d) {
|
||||
uint64_t bits = *reinterpret_cast<uint64_t *>(&d);
|
||||
|
||||
@@ -64,7 +64,7 @@ __device__ void negacyclic_forward_fft_f128(double *dt_re_hi, double *dt_re_lo,
|
||||
}
|
||||
|
||||
Index twiddle_shift = 1;
|
||||
for (Index l = LOG2_DEGREE - 1; l >= 1; --l) {
|
||||
for (Index l = LOG2_DEGREE - 1; l >= 5; --l) {
|
||||
Index lane_mask = 1 << (l - 1);
|
||||
Index thread_mask = (1 << l) - 1;
|
||||
twiddle_shift <<= 1;
|
||||
@@ -98,6 +98,39 @@ __device__ void negacyclic_forward_fft_f128(double *dt_re_hi, double *dt_re_lo,
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
// From here on we can do everything in warps
|
||||
|
||||
for (Index l = 5; l >= 1; --l) {
|
||||
Index lane_mask = 1 << (l - 1);
|
||||
Index thread_mask = (1 << l) - 1;
|
||||
twiddle_shift <<= 1;
|
||||
f128x2 warp_mem[BUTTERFLY_DEPTH];
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
|
||||
Index rank = tid & thread_mask;
|
||||
bool u_stays_in_register = rank < lane_mask;
|
||||
warp_mem[i] = (u_stays_in_register) ? v[i] : u[i];
|
||||
tid = tid + STRIDE;
|
||||
}
|
||||
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
|
||||
Index rank = tid & thread_mask;
|
||||
bool u_stays_in_register = rank < lane_mask;
|
||||
w = shfl_xor_f128x2(warp_mem[i], 1 << (l - 1), 0xFFFFFFFF);
|
||||
u[i] = (u_stays_in_register) ? u[i] : w;
|
||||
v[i] = (u_stays_in_register) ? w : v[i];
|
||||
w = NEG_TWID(tid / lane_mask + twiddle_shift);
|
||||
f128::cplx_f128_mul_assign(w.re, w.im, v[i].re, v[i].im, w.re, w.im);
|
||||
f128::cplx_f128_sub_assign(v[i].re, v[i].im, u[i].re, u[i].im, w.re,
|
||||
w.im);
|
||||
f128::cplx_f128_add_assign(u[i].re, u[i].im, u[i].re, u[i].im, w.re,
|
||||
w.im);
|
||||
tid = tid + STRIDE;
|
||||
}
|
||||
}
|
||||
|
||||
// store registers in SM
|
||||
tid = threadIdx.x;
|
||||
@@ -107,7 +140,6 @@ __device__ void negacyclic_forward_fft_f128(double *dt_re_hi, double *dt_re_lo,
|
||||
F128x2_TO_F64x4(v[i], (tid * 2 + 1));
|
||||
tid = tid + STRIDE;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
template <class params>
|
||||
@@ -131,9 +163,44 @@ __device__ void negacyclic_backward_fft_f128(double *dt_re_hi, double *dt_re_lo,
|
||||
F64x4_TO_F128x2(v[i], 2 * tid + 1);
|
||||
tid += STRIDE;
|
||||
}
|
||||
|
||||
// The first iterations can be solved within the warps
|
||||
Index twiddle_shift = DEGREE;
|
||||
for (Index l = 1; l <= LOG2_DEGREE - 1; ++l) {
|
||||
for (Index l = 1; l <= 5; ++l) {
|
||||
Index lane_mask = 1 << (l - 1);
|
||||
Index thread_mask = (1 << l) - 1;
|
||||
tid = threadIdx.x;
|
||||
twiddle_shift >>= 1;
|
||||
f128x2 warp_mem[BUTTERFLY_DEPTH];
|
||||
// at this point registers are ready for the butterfly
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
|
||||
w = (u[i] - v[i]);
|
||||
u[i] += v[i];
|
||||
v[i] = w * NEG_TWID(tid / lane_mask + twiddle_shift).conjugate();
|
||||
|
||||
// keep one of the registers for the next iteration and stage the other one in warp_mem for the warp shuffle
|
||||
Index rank = tid & thread_mask;
|
||||
bool u_stays_in_register = rank < lane_mask;
|
||||
warp_mem[i] = (u_stays_in_register) ? v[i] : u[i];
|
||||
tid = tid + STRIDE;
|
||||
}
|
||||
|
||||
// prepare registers for next butterfly iteration
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
|
||||
Index rank = tid & thread_mask;
|
||||
bool u_stays_in_register = rank < lane_mask;
|
||||
w = shfl_xor_f128x2(warp_mem[i], 1 << (l - 1), 0xFFFFFFFF);
|
||||
u[i] = (u_stays_in_register) ? u[i] : w;
|
||||
v[i] = (u_stays_in_register) ? w : v[i];
|
||||
|
||||
tid = tid + STRIDE;
|
||||
}
|
||||
}
|
||||
|
||||
for (Index l = 5; l <= LOG2_DEGREE - 1; ++l) {
|
||||
Index lane_mask = 1 << (l - 1);
|
||||
Index thread_mask = (1 << l) - 1;
|
||||
tid = threadIdx.x;
|
||||
|
||||
Reference in New Issue
Block a user