From 9b4faaa66ed36c40a0fd394d46af2a22ce90aa19 Mon Sep 17 00:00:00 2001 From: Beka Barbakadze Date: Thu, 16 Mar 2023 17:12:25 +0400 Subject: [PATCH] feat(concrete-cuda): unroll while loop for cuda fft and ifft --- .../implementation/src/fft/bnsmfft.cuh | 542 +++++++++++++++++- 1 file changed, 517 insertions(+), 25 deletions(-) diff --git a/backends/concrete-cuda/implementation/src/fft/bnsmfft.cuh b/backends/concrete-cuda/implementation/src/fft/bnsmfft.cuh index 9b68b8065..630857347 100644 --- a/backends/concrete-cuda/implementation/src/fft/bnsmfft.cuh +++ b/backends/concrete-cuda/implementation/src/fft/bnsmfft.cuh @@ -33,8 +33,6 @@ template __device__ void NSMFFT_direct(double2 *A) { size_t tid = threadIdx.x; size_t twid_id; - size_t t = params::degree / 2; - size_t m = 1; size_t i1, i2; double2 u, v, w; // level 1 @@ -44,7 +42,7 @@ template __device__ void NSMFFT_direct(double2 *A) { #pragma unroll for (size_t i = 0; i < params::opt / 2; ++i) { i1 = tid; - i2 = tid + t; + i2 = tid + params::degree / 2; u = A[i1]; v.x = (A[i2].x - A[i2].y) * 0.707106781186547461715008466854; v.y = (A[i2].x + A[i2].y) * 0.707106781186547461715008466854; @@ -57,22 +55,258 @@ template __device__ void NSMFFT_direct(double2 *A) { } __syncthreads(); - size_t iter = 1; - // for levels more than 1 - // from here none of the twiddles have equal real and imag part, so - // complete complex multiplication has to be done - // here we have more than one twiddles - while (t > 1) { - iter++; + // level 2 + // from this level there are more than one twiddles and none of them has equal + // real and imag parts, so complete complex multiplication is needed + // for each level params::degree / 2^level represents number of coefficients + // inside divided chunk of specific level + // + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 4); + i1 = 2 * (params::degree / 4) * twid_id + (tid & (params::degree / 4 - 1)); + i2 = i1 + params::degree / 4; + w = negtwiddles[twid_id + 2]; + u = A[i1]; + v.x = A[i2].x * w.x - A[i2].y * w.y; + v.y = A[i2].y * w.x + A[i2].x * w.y; + A[i1].x += v.x; + A[i1].y += v.y; + A[i2].x = u.x - v.x; + A[i2].y = u.y - v.y; + tid += params::degree / params::opt; + } + __syncthreads(); + + // level 3 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 8); + i1 = 2 * (params::degree / 8) * twid_id + (tid & (params::degree / 8 - 1)); + i2 = i1 + params::degree / 8; + w = negtwiddles[twid_id + 4]; + u = A[i1]; + v.x = A[i2].x * w.x - A[i2].y * w.y; + v.y = A[i2].y * w.x + A[i2].x * w.y; + A[i1].x += v.x; + A[i1].y += v.y; + A[i2].x = u.x - v.x; + A[i2].y = u.y - v.y; + tid += params::degree / params::opt; + } + __syncthreads(); + + // level 4 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 16); + i1 = + 2 * (params::degree / 16) * twid_id + (tid & (params::degree / 16 - 1)); + i2 = i1 + params::degree / 16; + w = negtwiddles[twid_id + 8]; + u = A[i1]; + v.x = A[i2].x * w.x - A[i2].y * w.y; + v.y = A[i2].y * w.x + A[i2].x * w.y; + A[i1].x += v.x; + A[i1].y += v.y; + A[i2].x = u.x - v.x; + A[i2].y = u.y - v.y; + tid += params::degree / params::opt; + } + __syncthreads(); + + // level 5 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 32); + i1 = + 2 * (params::degree / 32) * twid_id + (tid & (params::degree / 32 - 1)); + i2 = i1 + params::degree / 32; + w = negtwiddles[twid_id + 16]; + u = A[i1]; + v.x = A[i2].x * w.x - A[i2].y * w.y; + v.y = A[i2].y * w.x + A[i2].x * w.y; + A[i1].x += v.x; + A[i1].y += v.y; + A[i2].x = u.x - v.x; + A[i2].y = u.y - v.y; + tid += params::degree / params::opt; + } + __syncthreads(); + + // level 6 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 64); + i1 = + 2 * (params::degree / 64) * twid_id + (tid & (params::degree / 64 - 1)); + i2 = i1 + params::degree / 64; + w = negtwiddles[twid_id + 32]; + u = A[i1]; + v.x = A[i2].x * w.x - A[i2].y * w.y; + v.y = A[i2].y * w.x + A[i2].x * w.y; + A[i1].x += v.x; + A[i1].y += v.y; + A[i2].x = u.x - v.x; + A[i2].y = u.y - v.y; + tid += params::degree / params::opt; + } + __syncthreads(); + + // level 7 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 128); + i1 = 2 * (params::degree / 128) * twid_id + + (tid & (params::degree / 128 - 1)); + i2 = i1 + params::degree / 128; + w = negtwiddles[twid_id + 64]; + u = A[i1]; + v.x = A[i2].x * w.x - A[i2].y * w.y; + v.y = A[i2].y * w.x + A[i2].x * w.y; + A[i1].x += v.x; + A[i1].y += v.y; + A[i2].x = u.x - v.x; + A[i2].y = u.y - v.y; + tid += params::degree / params::opt; + } + __syncthreads(); + + // from level 8, we need to check size of params degree, because we support + // minimum actual polynomial size = 256, when compressed size is halfed and + // minimum supported compressed size is 128, so we always need first 7 + // levels of butterfy operation, since butterfly levels are hardcoded + // we need to check if polynomial size is big enough to require specific level + // of butterfly. + if constexpr (params::degree >= 256) { + // level 8 tid = threadIdx.x; - t >>= 1; - m <<= 1; #pragma unroll for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / t; - i1 = 2 * t * twid_id + (tid & (t - 1)); - i2 = i1 + t; - w = negtwiddles[twid_id + m]; + twid_id = tid / (params::degree / 256); + i1 = 2 * (params::degree / 256) * twid_id + + (tid & (params::degree / 256 - 1)); + i2 = i1 + params::degree / 256; + w = negtwiddles[twid_id + 128]; + u = A[i1]; + v.x = A[i2].x * w.x - A[i2].y * w.y; + v.y = A[i2].y * w.x + A[i2].x * w.y; + A[i1].x += v.x; + A[i1].y += v.y; + A[i2].x = u.x - v.x; + A[i2].y = u.y - v.y; + tid += params::degree / params::opt; + } + __syncthreads(); + } + + if constexpr (params::degree >= 512) { + // level 9 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 512); + i1 = 2 * (params::degree / 512) * twid_id + + (tid & (params::degree / 512 - 1)); + i2 = i1 + params::degree / 512; + w = negtwiddles[twid_id + 256]; + u = A[i1]; + v.x = A[i2].x * w.x - A[i2].y * w.y; + v.y = A[i2].y * w.x + A[i2].x * w.y; + A[i1].x += v.x; + A[i1].y += v.y; + A[i2].x = u.x - v.x; + A[i2].y = u.y - v.y; + tid += params::degree / params::opt; + } + __syncthreads(); + } + + if constexpr (params::degree >= 1024) { + // level 10 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 1024); + i1 = 2 * (params::degree / 1024) * twid_id + + (tid & (params::degree / 1024 - 1)); + i2 = i1 + params::degree / 1024; + w = negtwiddles[twid_id + 512]; + u = A[i1]; + v.x = A[i2].x * w.x - A[i2].y * w.y; + v.y = A[i2].y * w.x + A[i2].x * w.y; + A[i1].x += v.x; + A[i1].y += v.y; + A[i2].x = u.x - v.x; + A[i2].y = u.y - v.y; + tid += params::degree / params::opt; + } + __syncthreads(); + } + + if constexpr (params::degree >= 2048) { + // level 11 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 2048); + i1 = 2 * (params::degree / 2048) * twid_id + + (tid & (params::degree / 2048 - 1)); + i2 = i1 + params::degree / 2048; + w = negtwiddles[twid_id + 1024]; + u = A[i1]; + v.x = A[i2].x * w.x - A[i2].y * w.y; + v.y = A[i2].y * w.x + A[i2].x * w.y; + A[i1].x += v.x; + A[i1].y += v.y; + A[i2].x = u.x - v.x; + A[i2].y = u.y - v.y; + tid += params::degree / params::opt; + } + __syncthreads(); + } + + if constexpr (params::degree >= 4096) { + // level 12 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 4096); + i1 = 2 * (params::degree / 4096) * twid_id + + (tid & (params::degree / 4096 - 1)); + i2 = i1 + params::degree / 4096; + w = negtwiddles[twid_id + 2048]; + u = A[i1]; + v.x = A[i2].x * w.x - A[i2].y * w.y; + v.y = A[i2].y * w.x + A[i2].x * w.y; + A[i1].x += v.x; + A[i1].y += v.y; + A[i2].x = u.x - v.x; + A[i2].y = u.y - v.y; + tid += params::degree / params::opt; + } + __syncthreads(); + } + + // compressed size = 8192 is actual polynomial size = 16384. + // this size is not supported yet by any of the concrete-cuda api. + // may be used in the future. + if constexpr (params::degree >= 8192) { + // level 13 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 8192); + i1 = 2 * (params::degree / 8192) * twid_id + + (tid & (params::degree / 8192 - 1)); + i2 = i1 + params::degree / 8192; + w = negtwiddles[twid_id + 4096]; u = A[i1]; v.x = A[i2].x * w.x - A[i2].y * w.y; v.y = A[i2].y * w.x + A[i2].x * w.y; @@ -100,11 +334,10 @@ template __device__ void NSMFFT_inverse(double2 *A) { size_t tid = threadIdx.x; size_t twid_id; - size_t m = params::degree; - size_t t = 1; size_t i1, i2; double2 u, w; + // divide input by compressed polynomial size tid = threadIdx.x; for (size_t i = 0; i < params::opt; ++i) { A[tid].x *= 1. / params::degree; @@ -116,15 +349,22 @@ template __device__ void NSMFFT_inverse(double2 *A) { // none of the twiddles have equal real and imag part, so // complete complex multiplication has to be done // here we have more than one twiddle - while (m > 1) { + // mapping in backward fft is reversed + // butterfly operation is started from last level + + // compressed size = 8192 is actual polynomial size = 16384. + // this size is not supported yet by any of the concrete-cuda api. + // may be used in the future. + if constexpr (params::degree >= 8192) { + // level 13 tid = threadIdx.x; - m >>= 1; #pragma unroll for (size_t i = 0; i < params::opt / 2; ++i) { - twid_id = tid / t; - i1 = 2 * t * twid_id + (tid & (t - 1)); - i2 = i1 + t; - w = negtwiddles[twid_id + m]; + twid_id = tid / (params::degree / 8192); + i1 = 2 * (params::degree / 8192) * twid_id + + (tid & (params::degree / 8192 - 1)); + i2 = i1 + params::degree / 8192; + w = negtwiddles[twid_id + 4096]; u.x = A[i1].x - A[i2].x; u.y = A[i1].y - A[i2].y; A[i1].x += A[i2].x; @@ -134,9 +374,261 @@ template __device__ void NSMFFT_inverse(double2 *A) { A[i2].y = u.y * w.x - u.x * w.y; tid += params::degree / params::opt; } - t <<= 1; __syncthreads(); } + + if constexpr (params::degree >= 4096) { + // level 12 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 4096); + i1 = 2 * (params::degree / 4096) * twid_id + + (tid & (params::degree / 4096 - 1)); + i2 = i1 + params::degree / 4096; + w = negtwiddles[twid_id + 2048]; + u.x = A[i1].x - A[i2].x; + u.y = A[i1].y - A[i2].y; + A[i1].x += A[i2].x; + A[i1].y += A[i2].y; + + A[i2].x = u.x * w.x + u.y * w.y; + A[i2].y = u.y * w.x - u.x * w.y; + tid += params::degree / params::opt; + } + __syncthreads(); + } + + if constexpr (params::degree >= 2048) { + // level 11 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 2048); + i1 = 2 * (params::degree / 2048) * twid_id + + (tid & (params::degree / 2048 - 1)); + i2 = i1 + params::degree / 2048; + w = negtwiddles[twid_id + 1024]; + u.x = A[i1].x - A[i2].x; + u.y = A[i1].y - A[i2].y; + A[i1].x += A[i2].x; + A[i1].y += A[i2].y; + + A[i2].x = u.x * w.x + u.y * w.y; + A[i2].y = u.y * w.x - u.x * w.y; + tid += params::degree / params::opt; + } + __syncthreads(); + } + + if constexpr (params::degree >= 1024) { + // level 10 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 1024); + i1 = 2 * (params::degree / 1024) * twid_id + + (tid & (params::degree / 1024 - 1)); + i2 = i1 + params::degree / 1024; + w = negtwiddles[twid_id + 512]; + u.x = A[i1].x - A[i2].x; + u.y = A[i1].y - A[i2].y; + A[i1].x += A[i2].x; + A[i1].y += A[i2].y; + + A[i2].x = u.x * w.x + u.y * w.y; + A[i2].y = u.y * w.x - u.x * w.y; + tid += params::degree / params::opt; + } + __syncthreads(); + } + + if constexpr (params::degree >= 512) { + // level 9 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 512); + i1 = 2 * (params::degree / 512) * twid_id + + (tid & (params::degree / 512 - 1)); + i2 = i1 + params::degree / 512; + w = negtwiddles[twid_id + 256]; + u.x = A[i1].x - A[i2].x; + u.y = A[i1].y - A[i2].y; + A[i1].x += A[i2].x; + A[i1].y += A[i2].y; + + A[i2].x = u.x * w.x + u.y * w.y; + A[i2].y = u.y * w.x - u.x * w.y; + tid += params::degree / params::opt; + } + __syncthreads(); + } + + if constexpr (params::degree >= 256) { + // level 8 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 256); + i1 = 2 * (params::degree / 256) * twid_id + + (tid & (params::degree / 256 - 1)); + i2 = i1 + params::degree / 256; + w = negtwiddles[twid_id + 128]; + u.x = A[i1].x - A[i2].x; + u.y = A[i1].y - A[i2].y; + A[i1].x += A[i2].x; + A[i1].y += A[i2].y; + + A[i2].x = u.x * w.x + u.y * w.y; + A[i2].y = u.y * w.x - u.x * w.y; + tid += params::degree / params::opt; + } + __syncthreads(); + } + + // below level 8, we don't need to check size of params degree, because we + // support minimum actual polynomial size = 256, when compressed size is + // halfed and minimum supported compressed size is 128, so we always need + // last 7 levels of butterfy operation, since butterfly levels are hardcoded + // we don't need to check if polynomial size is big enough to require + // specific level of butterfly. + // level 7 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 128); + i1 = 2 * (params::degree / 128) * twid_id + + (tid & (params::degree / 128 - 1)); + i2 = i1 + params::degree / 128; + w = negtwiddles[twid_id + 64]; + u.x = A[i1].x - A[i2].x; + u.y = A[i1].y - A[i2].y; + A[i1].x += A[i2].x; + A[i1].y += A[i2].y; + + A[i2].x = u.x * w.x + u.y * w.y; + A[i2].y = u.y * w.x - u.x * w.y; + tid += params::degree / params::opt; + } + __syncthreads(); + + // level 6 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 64); + i1 = + 2 * (params::degree / 64) * twid_id + (tid & (params::degree / 64 - 1)); + i2 = i1 + params::degree / 64; + w = negtwiddles[twid_id + 32]; + u.x = A[i1].x - A[i2].x; + u.y = A[i1].y - A[i2].y; + A[i1].x += A[i2].x; + A[i1].y += A[i2].y; + + A[i2].x = u.x * w.x + u.y * w.y; + A[i2].y = u.y * w.x - u.x * w.y; + tid += params::degree / params::opt; + } + __syncthreads(); + + // level 5 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 32); + i1 = + 2 * (params::degree / 32) * twid_id + (tid & (params::degree / 32 - 1)); + i2 = i1 + params::degree / 32; + w = negtwiddles[twid_id + 16]; + u.x = A[i1].x - A[i2].x; + u.y = A[i1].y - A[i2].y; + A[i1].x += A[i2].x; + A[i1].y += A[i2].y; + + A[i2].x = u.x * w.x + u.y * w.y; + A[i2].y = u.y * w.x - u.x * w.y; + tid += params::degree / params::opt; + } + __syncthreads(); + + // level 4 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 16); + i1 = + 2 * (params::degree / 16) * twid_id + (tid & (params::degree / 16 - 1)); + i2 = i1 + params::degree / 16; + w = negtwiddles[twid_id + 8]; + u.x = A[i1].x - A[i2].x; + u.y = A[i1].y - A[i2].y; + A[i1].x += A[i2].x; + A[i1].y += A[i2].y; + + A[i2].x = u.x * w.x + u.y * w.y; + A[i2].y = u.y * w.x - u.x * w.y; + tid += params::degree / params::opt; + } + __syncthreads(); + + // level 3 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 8); + i1 = 2 * (params::degree / 8) * twid_id + (tid & (params::degree / 8 - 1)); + i2 = i1 + params::degree / 8; + w = negtwiddles[twid_id + 4]; + u.x = A[i1].x - A[i2].x; + u.y = A[i1].y - A[i2].y; + A[i1].x += A[i2].x; + A[i1].y += A[i2].y; + + A[i2].x = u.x * w.x + u.y * w.y; + A[i2].y = u.y * w.x - u.x * w.y; + tid += params::degree / params::opt; + } + __syncthreads(); + + // level 2 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 4); + i1 = 2 * (params::degree / 4) * twid_id + (tid & (params::degree / 4 - 1)); + i2 = i1 + params::degree / 4; + w = negtwiddles[twid_id + 2]; + u.x = A[i1].x - A[i2].x; + u.y = A[i1].y - A[i2].y; + A[i1].x += A[i2].x; + A[i1].y += A[i2].y; + + A[i2].x = u.x * w.x + u.y * w.y; + A[i2].y = u.y * w.x - u.x * w.y; + tid += params::degree / params::opt; + } + __syncthreads(); + + // level 1 + tid = threadIdx.x; +#pragma unroll + for (size_t i = 0; i < params::opt / 2; ++i) { + twid_id = tid / (params::degree / 2); + i1 = 2 * (params::degree / 2) * twid_id + (tid & (params::degree / 2 - 1)); + i2 = i1 + params::degree / 2; + w = negtwiddles[twid_id + 1]; + u.x = A[i1].x - A[i2].x; + u.y = A[i1].y - A[i2].y; + A[i1].x += A[i2].x; + A[i1].y += A[i2].y; + + A[i2].x = u.x * w.x + u.y * w.y; + A[i2].y = u.y * w.x - u.x * w.y; + tid += params::degree / params::opt; + } + __syncthreads(); } /*