feat(gpu): perform fft128 in warps when possible

commit 774f51fd5b
parent 1647ec8f21
Author: Guillermo Oyarzun
Date: 2025-08-19 11:58:26 +02:00

2 changed files with 81 additions and 4 deletions


@@ -259,6 +259,16 @@ struct f128x2 {
return *this;
}
};
// Warp shuffle-XOR for f128x2: exchanges each of the four 64-bit limbs
// with the lane whose ID differs by laneMask.
__device__ inline f128x2 shfl_xor_f128x2(f128x2 val, int laneMask,
unsigned mask = 0xFFFFFFFF) {
double rehi = __shfl_xor_sync(mask, val.re.hi, laneMask);
double relo = __shfl_xor_sync(mask, val.re.lo, laneMask);
double imhi = __shfl_xor_sync(mask, val.im.hi, laneMask);
double imlo = __shfl_xor_sync(mask, val.im.lo, laneMask);
return f128x2(f128(rehi, relo), f128(imhi, imlo));
}
__host__ __device__ inline uint64_t double_to_bits(double d) {
uint64_t bits = *reinterpret_cast<uint64_t *>(&d);

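The four limbs travel through four independent 32-lane shuffles, so the whole f128x2 exchange is just four register moves. The pattern is easier to see on an ordinary complex value; below is a minimal standalone sketch, not part of this commit, that applies the same limb-wise exchange to a plain double2 (the kernel name demo and the helper shfl_xor_double2 are illustrative assumptions):

#include <cstdio>

// Same limb-wise shuffle pattern as shfl_xor_f128x2, but for an ordinary
// double2 complex value: each 64-bit component is exchanged separately with
// the lane whose ID differs in the bits selected by laneMask.
__device__ inline double2 shfl_xor_double2(double2 val, int laneMask,
                                           unsigned mask = 0xFFFFFFFF) {
  double re = __shfl_xor_sync(mask, val.x, laneMask);
  double im = __shfl_xor_sync(mask, val.y, laneMask);
  return make_double2(re, im);
}

__global__ void demo() {
  int lane = threadIdx.x & 31;
  double2 v = make_double2(lane, -lane);
  // Exchange with the lane 16 apart, as in the l == 5 butterfly level.
  double2 w = shfl_xor_double2(v, 16);
  printf("lane %2d received (%g, %g)\n", lane, w.x, w.y);
}

int main() {
  demo<<<1, 32>>>();
  cudaDeviceSynchronize();
  return 0;
}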

@@ -64,7 +64,7 @@ __device__ void negacyclic_forward_fft_f128(double *dt_re_hi, double *dt_re_lo,
}
Index twiddle_shift = 1;
- for (Index l = LOG2_DEGREE - 1; l >= 1; --l) {
+ for (Index l = LOG2_DEGREE - 1; l >= 6; --l) {
Index lane_mask = 1 << (l - 1);
Index thread_mask = (1 << l) - 1;
twiddle_shift <<= 1;
@@ -98,6 +98,39 @@ __device__ void negacyclic_forward_fft_f128(double *dt_re_hi, double *dt_re_lo,
}
}
__syncthreads();
// From here on every butterfly partner lies in the same 32-thread warp
// (lane_mask <= 16), so the exchanges can use warp shuffles instead of
// shared memory.
for (Index l = 5; l >= 1; --l) {
Index lane_mask = 1 << (l - 1);
Index thread_mask = (1 << l) - 1;
twiddle_shift <<= 1;
f128x2 warp_mem[BUTTERFLY_DEPTH];
tid = threadIdx.x;
#pragma unroll
for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
Index rank = tid & thread_mask;
bool u_stays_in_register = rank < lane_mask;
warp_mem[i] = (u_stays_in_register) ? v[i] : u[i];
tid = tid + STRIDE;
}
tid = threadIdx.x;
#pragma unroll
for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
Index rank = tid & thread_mask;
bool u_stays_in_register = rank < lane_mask;
w = shfl_xor_f128x2(warp_mem[i], lane_mask, 0xFFFFFFFF);
u[i] = (u_stays_in_register) ? u[i] : w;
v[i] = (u_stays_in_register) ? w : v[i];
w = NEG_TWID(tid / lane_mask + twiddle_shift);
f128::cplx_f128_mul_assign(w.re, w.im, v[i].re, v[i].im, w.re, w.im);
f128::cplx_f128_sub_assign(v[i].re, v[i].im, u[i].re, u[i].im, w.re,
w.im);
f128::cplx_f128_add_assign(u[i].re, u[i].im, u[i].re, u[i].im, w.re,
w.im);
tid = tid + STRIDE;
}
}
// store registers back to shared memory
tid = threadIdx.x;
@@ -107,7 +140,6 @@ __device__ void negacyclic_forward_fft_f128(double *dt_re_hi, double *dt_re_lo,
F128x2_TO_F64x4(v[i], (tid * 2 + 1));
tid = tid + STRIDE;
}
- __syncthreads();
}
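Why the crossover sits at level 5: the butterfly partner of thread tid at level l is tid ^ (1 << (l - 1)), and both threads share a 32-thread warp exactly when 1 << (l - 1) < 32, i.e. l <= 5. A small host-side illustration, assuming LOG2_DEGREE = 11 purely as an example:

#include <cstdio>

int main() {
  const int warp_size = 32;
  for (int l = 10; l >= 1; --l) {          // LOG2_DEGREE - 1 = 10 (assumed)
    int lane_mask = 1 << (l - 1);
    // Partner of thread t is t ^ lane_mask; it stays in the same warp
    // exactly when lane_mask < warp_size.
    printf("l=%2d lane_mask=%4d -> %s\n", l, lane_mask,
           lane_mask < warp_size ? "warp shuffle" : "shared memory");
  }
  return 0;
}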
template <class params>
@@ -131,9 +163,44 @@ __device__ void negacyclic_backward_fft_f128(double *dt_re_hi, double *dt_re_lo,
F64x4_TO_F128x2(v[i], 2 * tid + 1);
tid += STRIDE;
}
// The first iterations can be solved within the warps
Index twiddle_shift = DEGREE;
- for (Index l = 1; l <= LOG2_DEGREE - 1; ++l) {
+ for (Index l = 1; l <= 5; ++l) {
Index lane_mask = 1 << (l - 1);
Index thread_mask = (1 << l) - 1;
tid = threadIdx.x;
twiddle_shift >>= 1;
f128x2 warp_mem[BUTTERFLY_DEPTH];
// at this point registers are ready for the butterfly
tid = threadIdx.x;
#pragma unroll
for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
w = (u[i] - v[i]);
u[i] += v[i];
v[i] = w * NEG_TWID(tid / lane_mask + twiddle_shift).conjugate();
// keep one of the registers for the next iteration and stage the other
// in warp_mem for the warp shuffle
Index rank = tid & thread_mask;
bool u_stays_in_register = rank < lane_mask;
warp_mem[i] = (u_stays_in_register) ? v[i] : u[i];
tid = tid + STRIDE;
}
// prepare registers for next butterfly iteration
tid = threadIdx.x;
#pragma unroll
for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
Index rank = tid & thread_mask;
bool u_stays_in_register = rank < lane_mask;
w = shfl_xor_f128x2(warp_mem[i], lane_mask, 0xFFFFFFFF);
u[i] = (u_stays_in_register) ? u[i] : w;
v[i] = (u_stays_in_register) ? w : v[i];
tid = tid + STRIDE;
}
}
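// Note the mirrored structure relative to the forward pass: the backward
// butterfly (add/sub plus conjugated twiddle) runs before the exchange, and
// the shuffle only prepares the registers for the next, coarser level.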
// the remaining levels pair threads across warps, so exchange through
// shared memory
for (Index l = 6; l <= LOG2_DEGREE - 1; ++l) {
Index lane_mask = 1 << (l - 1);
Index thread_mask = (1 << l) - 1;
tid = threadIdx.x;