Mirror of https://github.com/zama-ai/tfhe-rs.git (synced 2026-01-11 15:48:20 -05:00)

Compare commits: 9 commits (create-pul… → feat/gpu/p…)
| Author | SHA1 | Date |
|---|---|---|
|  | 7f085c7ea3 |  |
|  | c4232ded05 |  |
|  | d7d0613c4c |  |
|  | 03b4522ddf |  |
|  | de87056e31 |  |
|  | 2c4aa9bf3b |  |
|  | 24c131b590 |  |
|  | 199bbb3fd8 |  |
|  | 317a9c5709 |  |
@@ -62,6 +62,7 @@ fn main() {
         "cuda/include/integer/integer.h",
         "cuda/include/keyswitch.h",
         "cuda/include/linear_algebra.h",
+        "cuda/include/pbs/fft.h",
         "cuda/include/pbs/programmable_bootstrap.h",
         "cuda/include/pbs/programmable_bootstrap_multibit.h",
     ];
backends/tfhe-cuda-backend/cuda/include/pbs/fft.h (new file, 6 lines)
@@ -0,0 +1,6 @@
#include <stdint.h>

extern "C" {
void fourier_transform_forward_f128(void *stream, uint32_t gpu_index, void *re0,
                                    void *re1, void *im0, void *im1,
                                    void const *standard, uint32_t const N);
}
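For orientation, here is a hedged sketch of how a caller might drive this entry point from host C++. It is not part of the PR; the include path and the sizing of the four split output arrays (N/2 doubles each, matching the fourier polynomial size used elsewhere in this diff) are assumptions.

```cpp
// Hypothetical host-side caller for the new C entry point (not part of the PR).
#include <cuda_runtime.h>
#include <cstdint>
#include <vector>
#include "cuda/include/pbs/fft.h" // assumed include path

int main() {
  const uint32_t N = 2048; // polynomial size, one of the supported powers of two
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  std::vector<__uint128_t> standard(N); // input coefficients
  // split double-double representation: hi/lo words of the re and im parts
  std::vector<double> re0(N / 2), re1(N / 2), im0(N / 2), im1(N / 2);

  fourier_transform_forward_f128(stream, /*gpu_index=*/0, re0.data(),
                                 re1.data(), im0.data(), im1.data(),
                                 standard.data(), N);
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  return 0;
}
```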
backends/tfhe-cuda-backend/cuda/src/fft128/f128.cuh (new file, 316 lines)
@@ -0,0 +1,316 @@
#ifndef TFHE_RS_BACKENDS_TFHE_CUDA_BACKEND_CUDA_SRC_FFT128_F128_CUH_
#define TFHE_RS_BACKENDS_TFHE_CUDA_BACKEND_CUDA_SRC_FFT128_F128_CUH_

#include <cstdint>
#include <cstring>

struct alignas(16) f128 {
  double hi;
  double lo;

  // Default and parameterized constructors
  __host__ __device__ f128() : hi(0.0), lo(0.0) {}
  __host__ __device__ f128(double high, double low) : hi(high), lo(low) {}

  // Quick two-sum: assumes |a| >= |b|; s + err == a + b exactly
  __host__ __device__ __forceinline__ static f128 quick_two_sum(double a,
                                                                double b) {
    double s = a + b;
    return f128(s, b - (s - a));
  }

  // Two-sum (Knuth): s + err == a + b exactly, no ordering assumption
  __host__ __device__ __forceinline__ static f128 two_sum(double a, double b) {
    double s = a + b;
    double bb = s - a;
    return f128(s, (a - (s - bb)) + (b - bb));
  }

  // Two-product: p + err == a * b exactly, using FMA to recover the error
  __host__ __device__ __forceinline__ static f128 two_prod(double a, double b) {
    double p = a * b;
#ifdef __CUDA_ARCH__
    return f128(p, __fma_rn(a, b, -p));
#else
    return f128(p, fma(a, b, -p));
#endif
  }

  // Two-difference: s + err == a - b exactly
  __host__ __device__ __forceinline__ static f128 two_diff(double a, double b) {
    double s = a - b;
    double bb = s - a;
    return f128(s, (a - (s - bb)) - (b + bb));
  }

  // Addition
  __host__ __device__ static f128 add(const f128 &a, const f128 &b) {
    auto s = two_sum(a.hi, b.hi);
    auto t = two_sum(a.lo, b.lo);

    double hi = s.hi;
    double lo = s.lo + t.hi;
    hi = hi + lo;
    lo = lo - (hi - s.hi);

    return f128(hi, lo + t.lo);
  }

  // Addition with estimate (cheaper, slightly looser renormalization)
  __host__ __device__ static f128 add_estimate(const f128 &a, const f128 &b) {
    auto se = two_sum(a.hi, b.hi);
    double hi = se.hi;
    double lo = se.lo + a.lo + b.lo;

    hi = hi + lo;
    lo = lo - (hi - se.hi);

    return f128(hi, lo);
  }

  // Subtraction with estimate
  __host__ __device__ static f128 sub_estimate(const f128 &a, const f128 &b) {
    f128 se = two_diff(a.hi, b.hi);
    se.lo += a.lo;
    se.lo -= b.lo;
    return quick_two_sum(se.hi, se.lo);
  }

  // Subtraction
  __host__ __device__ static f128 sub(const f128 &a, const f128 &b) {
    auto s = two_diff(a.hi, b.hi);
    auto t = two_diff(a.lo, b.lo);
    s = quick_two_sum(s.hi, s.lo + t.hi);
    return quick_two_sum(s.hi, s.lo + t.lo);
  }

  // Multiplication
  __host__ __device__ static f128 mul(const f128 &a, const f128 &b) {
    double hi, lo;
    auto p = two_prod(a.hi, b.hi);
    hi = p.hi;
    lo = p.lo + (a.hi * b.lo + a.lo * b.hi);

    hi = hi + lo;
    lo = lo - (hi - p.hi);

    return f128(hi, lo);
  }

  // Note: as written this computes c = a * conj(b), i.e.
  // re = a_re*b_re + a_im*b_im and im = a_im*b_re - a_re*b_im.
  __host__ __device__ static void
  cplx_f128_mul_assign(f128 &c_re, f128 &c_im, const f128 &a_re,
                       const f128 &a_im, const f128 &b_re, const f128 &b_im) {
    auto a_re_x_b_re = mul(a_re, b_re);
    auto a_re_x_b_im = mul(a_re, b_im);
    auto a_im_x_b_re = mul(a_im, b_re);
    auto a_im_x_b_im = mul(a_im, b_im);

    c_re = add_estimate(a_re_x_b_re, a_im_x_b_im);
    c_im = sub_estimate(a_im_x_b_re, a_re_x_b_im);
  }
};
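The struct is a textbook double-double type built on the Dekker/Knuth error-free transformations: `two_sum` and `two_prod` return the rounded result together with its exact rounding error, so each f128 represents the unevaluated sum hi + lo, giving roughly 106 bits of effective significand. A minimal host-only C++ sketch (re-stated without the CUDA attributes) of why the error term is exact:

```cpp
// Host-only illustration of the error-free transformations (plain C++,
// no CUDA needed; mirrors f128::two_sum / f128::two_prod above).
#include <cmath>
#include <cstdio>

// Knuth two-sum: s + err == a + b exactly.
static void two_sum(double a, double b, double &s, double &err) {
  s = a + b;
  double bb = s - a;
  err = (a - (s - bb)) + (b - bb);
}

// Two-product via FMA: p + err == a * b exactly.
static void two_prod(double a, double b, double &p, double &err) {
  p = a * b;
  err = std::fma(a, b, -p);
}

int main() {
  double s, e, p;
  two_sum(1e16, 1.0, s, e); // 1e16 + 1 is not representable in a double
  printf("two_sum:  s = %.17g, err = %.17g\n", s, e); // err recovers the lost 1
  two_prod(1e8 + 1.0, 1e8 + 1.0, p, e); // (1e8 + 1)^2 = 1e16 + 2e8 + 1
  printf("two_prod: p = %.17g, err = %.17g\n", p, e); // err holds the lost 1
  return 0;
}
```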
struct f128x2 {
  f128 re;
  f128 im;

  __host__ __device__ f128x2() : re(), im() {}

  __host__ __device__ f128x2(const f128 &real, const f128 &imag)
      : re(real), im(imag) {}

  __host__ __device__ f128x2(double real, double imag)
      : re(real, 0.0), im(imag, 0.0) {}

  __host__ __device__ explicit f128x2(double real)
      : re(real, 0.0), im(0.0, 0.0) {}

  __host__ __device__ f128x2(const f128x2 &other)
      : re(other.re), im(other.im) {}

  __host__ __device__ f128x2(f128x2 &&other) noexcept
      : re(std::move(other.re)), im(std::move(other.im)) {}

  __host__ __device__ f128x2 &operator=(const f128x2 &other) {
    if (this != &other) {
      re = other.re;
      im = other.im;
    }
    return *this;
  }

  __host__ __device__ f128x2 &operator=(f128x2 &&other) noexcept {
    if (this != &other) {
      re = std::move(other.re);
      im = std::move(other.im);
    }
    return *this;
  }

  __host__ __device__ f128x2 conjugate() const {
    return f128x2(re, f128(-im.hi, -im.lo));
  }

  __host__ __device__ f128 norm_squared() const {
    return f128::add(f128::mul(re, re), f128::mul(im, im));
  }

  __host__ __device__ void zero() {
    re = f128(0.0, 0.0);
    im = f128(0.0, 0.0);
  }

  // Addition
  __host__ __device__ friend f128x2 operator+(const f128x2 &a,
                                              const f128x2 &b) {
    return f128x2(f128::add(a.re, b.re), f128::add(a.im, b.im));
  }

  // Subtraction
  __host__ __device__ friend f128x2 operator-(const f128x2 &a,
                                              const f128x2 &b) {
    return f128x2(f128::add(a.re, f128(-b.re.hi, -b.re.lo)),
                  f128::add(a.im, f128(-b.im.hi, -b.im.lo)));
  }

  // Multiplication (complex multiplication)
  __host__ __device__ friend f128x2 operator*(const f128x2 &a,
                                              const f128x2 &b) {
    f128 im_prod = f128::mul(a.im, b.im); // computed once instead of twice
    f128 real_part =
        f128::add(f128::mul(a.re, b.re), f128(-im_prod.hi, -im_prod.lo));
    f128 imag_part = f128::add(f128::mul(a.re, b.im), f128::mul(a.im, b.re));
    return f128x2(real_part, imag_part);
  }

  // Addition-assignment operator
  __host__ __device__ f128x2 &operator+=(const f128x2 &other) {
    re = f128::add(re, other.re);
    im = f128::add(im, other.im);
    return *this;
  }

  // Subtraction-assignment operator
  __host__ __device__ f128x2 &operator-=(const f128x2 &other) {
    re = f128::add(re, f128(-other.re.hi, -other.re.lo));
    im = f128::add(im, f128(-other.im.hi, -other.im.lo));
    return *this;
  }

  // Multiplication-assignment operator
  __host__ __device__ f128x2 &operator*=(const f128x2 &other) {
    f128 im_prod = f128::mul(im, other.im);
    f128 new_re =
        f128::add(f128::mul(re, other.re), f128(-im_prod.hi, -im_prod.lo));
    f128 new_im = f128::add(f128::mul(re, other.im), f128::mul(im, other.re));
    re = new_re;
    im = new_im;
    return *this;
  }
};
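A quick host-side sanity check of the complex operators; a hypothetical test harness, compiled with nvcc since the struct carries `__host__ __device__` annotations, and the include path is assumed:

```cpp
// Hypothetical host-side sanity check for f128x2 (not part of the PR).
#include <cstdio>
#include "f128.cuh" // assumed include path; build with nvcc

int main() {
  f128x2 a(3.0, 4.0);  // 3 + 4i
  f128x2 b(1.0, -2.0); // 1 - 2i
  f128x2 p = a * b;    // (3 + 4i)(1 - 2i) = 11 - 2i
  printf("p = %g%+gi\n", p.re.hi, p.im.hi);    // expect 11 - 2i
  printf("|a|^2 = %g\n", a.norm_squared().hi); // expect 25
  printf("conj(a) = %g%+gi\n", a.conjugate().re.hi,
         a.conjugate().im.hi);                 // expect 3 - 4i
  return 0;
}
```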
__host__ __device__ inline uint64_t double_to_bits(double d) {
  uint64_t bits;
  std::memcpy(&bits, &d, sizeof(bits));
  return bits;
}

__host__ __device__ inline double bits_to_double(uint64_t bits) {
  double d;
  std::memcpy(&d, &bits, sizeof(d));
  return d;
}

__host__ __device__ double u128_to_f64(__uint128_t x) {
  const __uint128_t ONE = 1;
  const double A = ONE << 52;
  const double B = ONE << 104;
  const double C = ONE << 76;
  const double D = 340282366920938500000000000000000000000.; // ~2^128

  const __uint128_t threshold = (ONE << 104);

  if (x < threshold) {
    uint64_t A_bits = double_to_bits(A);

    // low 52 bits of x, OR-ed into the mantissa of A = 2^52
    __uint128_t shifted = (x << 12);
    uint64_t lower64 = static_cast<uint64_t>(shifted);
    lower64 >>= 12;

    uint64_t bits_l = A_bits | lower64;
    double l_temp = bits_to_double(bits_l);
    double l = l_temp - A;

    // high bits of x, OR-ed into the mantissa of B = 2^104
    uint64_t B_bits = double_to_bits(B);
    uint64_t top64 = static_cast<uint64_t>(x >> 52);
    uint64_t bits_h = B_bits | top64;
    double h_temp = bits_to_double(bits_h);
    double h = h_temp - B;

    return (l + h);

  } else {
    uint64_t C_bits = double_to_bits(C);

    // bits [24..76) of x; the low 24 bits are OR-ed in below
    __uint128_t shifted = (x >> 12);
    uint64_t lower64 = static_cast<uint64_t>(shifted);
    lower64 >>= 12;

    uint64_t x_lo = static_cast<uint64_t>(x);
    uint64_t mask_part = (x_lo & 0xFFFFFFULL);

    uint64_t bits_l = C_bits | lower64 | mask_part;
    double l_temp = bits_to_double(bits_l);
    double l = l_temp - C;

    uint64_t D_bits = double_to_bits(D);
    uint64_t top64 = static_cast<uint64_t>(x >> 76);
    uint64_t bits_h = D_bits | top64;
    double h_temp = bits_to_double(bits_h);
    double h = h_temp - D;

    return (l + h);
  }
}
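Why the first branch is exact (a reading of the constants, not stated in the diff): for $x < 2^{104}$, split $x = h \cdot 2^{52} + \ell$ with $0 \le \ell < 2^{52}$. Since $A = 2^{52}$ has an all-zero mantissa, OR-ing $\ell$ into $\mathrm{bits}(A)$ encodes the exact double $2^{52} + \ell$, and likewise OR-ing $h$ into $\mathrm{bits}(B)$ with $B = 2^{104}$ encodes the exact double $2^{104} + h \cdot 2^{52}$, so

$$\bigl(\mathrm{f64}(\mathrm{bits}(A) \mid \ell) - A\bigr) + \bigl(\mathrm{f64}(\mathrm{bits}(B) \mid h) - B\bigr) = \ell + h \cdot 2^{52} = x,$$

with a single rounding in the final addition. The $C = 2^{76}$ / $D \approx 2^{128}$ branch plays the same game for $x \ge 2^{104}$; there the low 24 bits of $x$ no longer fit in the low part's mantissa, so they are OR-ed in so that the final rounding still accounts for them.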
__host__ __device__ __uint128_t f64_to_u128(const double f) {
  const __uint128_t ONE = 1;
  const uint64_t f_bits = double_to_bits(f);
  if (f_bits < 1023ull << 52) {
    // values below 1.0 truncate to 0; callers pass f >= 0
    return 0;
  } else {
    // reconstruct the mantissa at the top of a u128, then shift by the exponent
    const __uint128_t m = ONE << 127 | (__uint128_t)f_bits << 75;
    const uint64_t s = 1150 - (f_bits >> 52);
    if (s >= 128) {
      return 0;
    } else {
      return m >> s;
    }
  }
}
__host__ __device__ double i128_to_f64(__int128_t const x) {
  uint64_t sign = static_cast<uint64_t>(x >> 64) & (1ULL << 63);
  __uint128_t abs = (x < 0) ? static_cast<__uint128_t>(-x)
                            : static_cast<__uint128_t>(x);

  return bits_to_double(double_to_bits(u128_to_f64(abs)) | sign);
}
__host__ __device__ f128 u128_to_signed_to_f128(__uint128_t x) {
  // first approximation, interpreting x as a signed (torus) value
  const double first_approx = i128_to_f64(x);
  const uint64_t sign_bit = double_to_bits(first_approx) & (1ull << 63);
  const __uint128_t first_approx_roundtrip =
      f64_to_u128((first_approx < 0) ? -first_approx : first_approx);
  const __uint128_t first_approx_roundtrip_signed =
      (sign_bit == (1ull << 63)) ? -first_approx_roundtrip
                                 : first_approx_roundtrip;

  // the remainder left after the roundtrip becomes the low word of the f128
  double correction = i128_to_f64(x - first_approx_roundtrip_signed);

  return f128(first_approx, correction);
}

#endif
backends/tfhe-cuda-backend/cuda/src/fft128/fft128.cu (new file, 42 lines)
@@ -0,0 +1,42 @@
#include "fft128.cuh"

void fourier_transform_forward_f128(void *stream, uint32_t gpu_index, void *re0,
                                    void *re1, void *im0, void *im1,
                                    void const *standard, uint32_t const N) {
  switch (N) {
  case 256:
    host_fourier_transform_forward_f128_split_input<Degree<256>>(
        static_cast<cudaStream_t>(stream), gpu_index, (double *)re0,
        (double *)re1, (double *)im0, (double *)im1,
        (__uint128_t const *)standard, N);
    break;
  case 512:
    host_fourier_transform_forward_f128_split_input<Degree<512>>(
        static_cast<cudaStream_t>(stream), gpu_index, (double *)re0,
        (double *)re1, (double *)im0, (double *)im1,
        (__uint128_t const *)standard, N);
    break;
  case 1024:
    host_fourier_transform_forward_f128_split_input<Degree<1024>>(
        static_cast<cudaStream_t>(stream), gpu_index, (double *)re0,
        (double *)re1, (double *)im0, (double *)im1,
        (__uint128_t const *)standard, N);
    break;
  case 2048:
    host_fourier_transform_forward_f128_split_input<Degree<2048>>(
        static_cast<cudaStream_t>(stream), gpu_index, (double *)re0,
        (double *)re1, (double *)im0, (double *)im1,
        (__uint128_t const *)standard, N);
    break;
  case 4096:
    host_fourier_transform_forward_f128_split_input<Degree<4096>>(
        static_cast<cudaStream_t>(stream), gpu_index, (double *)re0,
        (double *)re1, (double *)im0, (double *)im1,
        (__uint128_t const *)standard, N);
    break;
  default:
    PANIC("Cuda error (f128 fft): unsupported polynomial size. Supported "
          "N's are powers of two in the interval [256..4096].")
  }
}
backends/tfhe-cuda-backend/cuda/src/fft128/fft128.cuh (new file, 330 lines)
@@ -0,0 +1,330 @@
#ifndef TFHE_RS_BACKENDS_TFHE_CUDA_BACKEND_CUDA_SRC_FFT128_FFT128_CUH_
#define TFHE_RS_BACKENDS_TFHE_CUDA_BACKEND_CUDA_SRC_FFT128_FFT128_CUH_

#include "f128.cuh"
#include "pbs/fft.h"
#include "polynomial/functions.cuh"
#include "polynomial/parameters.cuh"
#include "twiddles.cuh"
#include "types/complex/operations.cuh"
#include <iostream>

using Index = unsigned;

#define NEG_TWID(i)                                                            \
  f128x2(f128(neg_twiddles_re_hi[(i)], neg_twiddles_re_lo[(i)]),               \
         f128(neg_twiddles_im_hi[(i)], neg_twiddles_im_lo[(i)]))

#define F64x4_TO_F128x2(f128x2_reg, ind)                                       \
  f128x2_reg.re.hi = dt_re_hi[ind];                                            \
  f128x2_reg.re.lo = dt_re_lo[ind];                                            \
  f128x2_reg.im.hi = dt_im_hi[ind];                                            \
  f128x2_reg.im.lo = dt_im_lo[ind];

#define F128x2_TO_F64x4(f128x2_reg, ind)                                       \
  dt_re_hi[ind] = f128x2_reg.re.hi;                                            \
  dt_re_lo[ind] = f128x2_reg.re.lo;                                            \
  dt_im_hi[ind] = f128x2_reg.im.hi;                                            \
  dt_im_lo[ind] = f128x2_reg.im.lo;

// Naming conventions:
// zl - left input of a butterfly operation
// zr - right input of a butterfly operation
// re - real part
// im - imaginary part
// hi - high word of an f128
// lo - low word of an f128
// dt - data array
// cf - single coefficient
template <class params>
__device__ void negacyclic_forward_fft_f128(double *dt_re_hi, double *dt_re_lo,
                                            double *dt_im_hi,
                                            double *dt_im_lo) {

  __syncthreads();
  constexpr Index BUTTERFLY_DEPTH = params::opt >> 1;
  constexpr Index LOG2_DEGREE = params::log2_degree;
  constexpr Index HALF_DEGREE = params::degree >> 1;
  constexpr Index STRIDE = params::degree / params::opt;

  f128x2 u[BUTTERFLY_DEPTH], v[BUTTERFLY_DEPTH], w;

  Index tid = threadIdx.x;

  // load into registers
#pragma unroll
  for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
    F64x4_TO_F128x2(u[i], tid);
    F64x4_TO_F128x2(v[i], tid + HALF_DEGREE);
    tid += STRIDE;
  }

  // level 1
  // We could avoid a full complex multiplication on level 1: there is only
  // one twiddle and its real and imaginary parts are equal, so the product
  // could be formed with cheaper operations.
#pragma unroll
  for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
    w = v[i] * NEG_TWID(1);
    v[i] = u[i] - w;
    u[i] = u[i] + w;
  }

  Index twiddle_shift = 1;
  for (Index l = LOG2_DEGREE - 1; l >= 1; --l) {
    Index lane_mask = 1 << (l - 1);
    Index thread_mask = (1 << l) - 1;
    twiddle_shift <<= 1;

    // exchange operands between threads through shared memory
    tid = threadIdx.x;
    __syncthreads();
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
      Index rank = tid & thread_mask;
      bool u_stays_in_register = rank < lane_mask;
      F128x2_TO_F64x4((u_stays_in_register) ? v[i] : u[i], tid);
      tid = tid + STRIDE;
    }
    __syncthreads();

    tid = threadIdx.x;
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
      Index rank = tid & thread_mask;
      bool u_stays_in_register = rank < lane_mask;
      F64x4_TO_F128x2(w, tid ^ lane_mask);
      u[i] = (u_stays_in_register) ? u[i] : w;
      v[i] = (u_stays_in_register) ? w : v[i];
      w = NEG_TWID(tid / lane_mask + twiddle_shift);

      w *= v[i];

      v[i] = u[i] - w;
      u[i] = u[i] + w;
      tid = tid + STRIDE;
    }
  }
  __syncthreads();

  // store registers in shared memory
  tid = threadIdx.x;
#pragma unroll
  for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
    F128x2_TO_F64x4(u[i], tid * 2);
    F128x2_TO_F64x4(v[i], tid * 2 + 1);
    tid = tid + STRIDE;
  }
  __syncthreads();
}
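For reference, the transform these kernels compute is the negacyclic DFT: a real polynomial in $\mathbb{R}[X]/(X^N+1)$ is evaluated at the odd roots $\omega_k = e^{i\pi(2k+1)/N}$, and only the first $N/2$ evaluations are kept since the remaining ones are their conjugates. A hedged, host-only $O(N^2)$ reference in plain C++, useful for cross-checking the kernels on small sizes (not part of the backend):

```cpp
// Naive negacyclic DFT reference (hypothetical cross-check, not in the PR).
#include <cmath>
#include <complex>
#include <cstdio>
#include <vector>

using cplx = std::complex<double>;

// Forward: keep evaluations at w_k = exp(i*pi*(2k+1)/N) for k < N/2.
static std::vector<cplx> forward(const std::vector<double> &a) {
  const size_t n = a.size();
  std::vector<cplx> out(n / 2);
  for (size_t k = 0; k < n / 2; ++k) {
    cplx acc = 0;
    for (size_t j = 0; j < n; ++j)
      acc += a[j] * std::polar(1.0, M_PI * double((2 * k + 1) * j) / double(n));
    out[k] = acc;
  }
  return out;
}

// Inverse: a[j] = (2/N) * Re( sum_{k < N/2} f[k] * conj(w_k)^j ).
static std::vector<double> inverse(const std::vector<cplx> &f, size_t n) {
  std::vector<double> a(n);
  for (size_t j = 0; j < n; ++j) {
    cplx acc = 0;
    for (size_t k = 0; k < n / 2; ++k)
      acc += f[k] * std::polar(1.0, -M_PI * double((2 * k + 1) * j) / double(n));
    a[j] = 2.0 * acc.real() / double(n);
  }
  return a;
}

int main() {
  std::vector<double> a = {1, 2, 3, 4, 5, 6, 7, 8};
  auto r = inverse(forward(a), a.size());
  for (size_t j = 0; j < a.size(); ++j)
    printf("%g %g\n", a[j], r[j]); // columns should agree to ~1e-12
  return 0;
}
```

The butterfly network above computes the same evaluations in $O(N \log N)$, exchanging operands through shared memory between levels.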
template <class params>
__device__ void negacyclic_inverse_fft_f128(double *dt_re_hi, double *dt_re_lo,
                                            double *dt_im_hi,
                                            double *dt_im_lo) {
  __syncthreads();
  constexpr Index BUTTERFLY_DEPTH = params::opt >> 1;
  constexpr Index LOG2_DEGREE = params::log2_degree;
  constexpr Index DEGREE = params::degree;
  constexpr Index HALF_DEGREE = params::degree >> 1;
  constexpr Index STRIDE = params::degree / params::opt;

  size_t tid = threadIdx.x;
  f128x2 u[BUTTERFLY_DEPTH], v[BUTTERFLY_DEPTH], w;

  // load into registers and divide by compressed polynomial size
#pragma unroll
  for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
    F64x4_TO_F128x2(u[i], 2 * tid);
    F64x4_TO_F128x2(v[i], 2 * tid + 1);

    // TODO f128 / double and f128x2 / double
    // u[i] /= DEGREE;
    // v[i] /= DEGREE;

    tid += STRIDE;
  }

  Index twiddle_shift = DEGREE;
  for (Index l = 1; l <= LOG2_DEGREE - 1; ++l) {
    Index lane_mask = 1 << (l - 1);
    Index thread_mask = (1 << l) - 1;
    twiddle_shift >>= 1;

    // at this point registers are ready for the butterfly
    tid = threadIdx.x;
    __syncthreads();
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
      w = (u[i] - v[i]);
      u[i] += v[i];
      v[i] = w * NEG_TWID(tid / lane_mask + twiddle_shift).conjugate();

      // keep one of the registers for the next iteration and store the other
      // one in shared memory
      Index rank = tid & thread_mask;
      bool u_stays_in_register = rank < lane_mask;
      F128x2_TO_F64x4((u_stays_in_register) ? v[i] : u[i], tid);

      tid = tid + STRIDE;
    }
    __syncthreads();

    // prepare registers for next butterfly iteration
    tid = threadIdx.x;
#pragma unroll
    for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
      Index rank = tid & thread_mask;
      bool u_stays_in_register = rank < lane_mask;
      F64x4_TO_F128x2(w, tid ^ lane_mask);

      u[i] = (u_stays_in_register) ? u[i] : w;
      v[i] = (u_stays_in_register) ? w : v[i];

      tid = tid + STRIDE;
    }
  }

  // last iteration
  for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
    w = (u[i] - v[i]);
    u[i] = u[i] + v[i];
    v[i] = w * NEG_TWID(1).conjugate();
  }
  __syncthreads();
  // store registers in shared memory
  tid = threadIdx.x;
#pragma unroll
  for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
    F128x2_TO_F64x4(u[i], tid);
    F128x2_TO_F64x4(v[i], tid + HALF_DEGREE);

    tid = tid + STRIDE;
  }
  __syncthreads();
}
// params is expected to be the full degree, not the half degree
template <class params>
__device__ void convert_u128_to_f128(double *out_re_hi, double *out_re_lo,
                                     double *out_im_hi, double *out_im_lo,
                                     const __uint128_t *in_re,
                                     const __uint128_t *in_im) {

  Index tid = threadIdx.x;
#pragma unroll
  for (Index i = 0; i < params::opt / 2; i++) {
    auto out_re = u128_to_signed_to_f128(in_re[tid]);
    auto out_im = u128_to_signed_to_f128(in_im[tid]);

    out_re_hi[tid] = out_re.hi;
    out_re_lo[tid] = out_re.lo;
    out_im_hi[tid] = out_im.hi;
    out_im_lo[tid] = out_im.lo;

    // advance to the next coefficient handled by this thread
    tid += params::degree / params::opt;
  }
}

// params is expected to be the full degree, not the half degree
template <class params>
__global__ void batch_convert_u128_to_f128(double *out_re_hi,
                                           double *out_re_lo,
                                           double *out_im_hi,
                                           double *out_im_lo,
                                           const __uint128_t *in) {

  convert_u128_to_f128<params>(&out_re_hi[blockIdx.x * params::degree / 2],
                               &out_re_lo[blockIdx.x * params::degree / 2],
                               &out_im_hi[blockIdx.x * params::degree / 2],
                               &out_im_lo[blockIdx.x * params::degree / 2],
                               &in[blockIdx.x * params::degree],
                               &in[blockIdx.x * params::degree +
                                   params::degree / 2]);
}
template <class params, sharedMemDegree SMD>
__global__ void batch_NSMFFT_128(double *in_re_hi, double *in_re_lo,
                                 double *in_im_hi, double *in_im_lo,
                                 double *out_re_hi, double *out_re_lo,
                                 double *out_im_hi, double *out_im_lo,
                                 double *buffer) {
  extern __shared__ double sharedMemoryFFT[];
  double *re_hi, *re_lo, *im_hi, *im_lo;
  if (SMD == NOSM) {
    re_hi =
        &buffer[blockIdx.x * params::degree / 2 * 4 + params::degree / 2 * 0];
    re_lo =
        &buffer[blockIdx.x * params::degree / 2 * 4 + params::degree / 2 * 1];
    im_hi =
        &buffer[blockIdx.x * params::degree / 2 * 4 + params::degree / 2 * 2];
    im_lo =
        &buffer[blockIdx.x * params::degree / 2 * 4 + params::degree / 2 * 3];
  } else {
    re_hi = &sharedMemoryFFT[params::degree / 2 * 0];
    re_lo = &sharedMemoryFFT[params::degree / 2 * 1];
    im_hi = &sharedMemoryFFT[params::degree / 2 * 2];
    im_lo = &sharedMemoryFFT[params::degree / 2 * 3];
  }

  Index tid = threadIdx.x;
#pragma unroll
  for (Index i = 0; i < params::opt / 2; ++i) {
    re_hi[tid] = in_re_hi[blockIdx.x * (params::degree / 2) + tid];
    re_lo[tid] = in_re_lo[blockIdx.x * (params::degree / 2) + tid];
    im_hi[tid] = in_im_hi[blockIdx.x * (params::degree / 2) + tid];
    im_lo[tid] = in_im_lo[blockIdx.x * (params::degree / 2) + tid];
    tid += params::degree / params::opt;
  }
  __syncthreads();
  negacyclic_forward_fft_f128<HalfDegree<params>>(re_hi, re_lo, im_hi, im_lo);
  __syncthreads();

  tid = threadIdx.x;
#pragma unroll
  for (Index i = 0; i < params::opt / 2; ++i) {
    out_re_hi[blockIdx.x * (params::degree / 2) + tid] = re_hi[tid];
    out_re_lo[blockIdx.x * (params::degree / 2) + tid] = re_lo[tid];
    out_im_hi[blockIdx.x * (params::degree / 2) + tid] = im_hi[tid];
    out_im_lo[blockIdx.x * (params::degree / 2) + tid] = im_lo[tid];
    tid += params::degree / params::opt;
  }
}
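A hedged sketch of how this kernel might be launched. The geometry is inferred from the addressing above: degree/opt threads per block so the opt/2-iteration copy loops cover degree/2 elements, one block per polynomial, and four split arrays of degree/2 doubles in shared memory. FULLSM is assumed to be the shared-memory variant of the sharedMemDegree enum.

```cpp
// Hypothetical launcher (not part of the PR); geometry is an inference.
template <class params>
void launch_batch_forward_fft128(cudaStream_t stream, uint32_t batch,
                                 double *re_hi, double *re_lo,
                                 double *im_hi, double *im_lo) {
  dim3 grid(batch);                         // one block per polynomial
  dim3 block(params::degree / params::opt); // opt/2 coefficients per thread
  // four split arrays (re_hi, re_lo, im_hi, im_lo) of degree/2 doubles each
  size_t shared_bytes = 4 * (params::degree / 2) * sizeof(double);
  batch_NSMFFT_128<params, FULLSM>
      <<<grid, block, shared_bytes, stream>>>(re_hi, re_lo, im_hi, im_lo,
                                              re_hi, re_lo, im_hi, im_lo,
                                              nullptr); // buffer unused here
  check_cuda_error(cudaGetLastError());
}
```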
void print_uint128_bits(__uint128_t value) {
  char buffer[129];   // 128 bits + null terminator
  buffer[128] = '\0'; // null-terminate the string

  for (int i = 127; i >= 0; --i) {
    buffer[i] = (value & 1) ? '1' : '0'; // extract the least significant bit
    value >>= 1;                         // shift right by 1 bit
  }

  printf("%s\n", buffer);
}
template <class params>
__host__ void host_fourier_transform_forward_f128_split_input(
    cudaStream_t stream, uint32_t gpu_index, double *re0, double *re1,
    double *im0, double *im1, __uint128_t const *standard, uint32_t const N) {

  // Debug scaffolding for now: print the input and exercise device
  // allocation; the conversion and FFT kernels are not launched yet.
  printf("cpp_poly_host\n");
  for (uint32_t i = 0; i < N; i++) {
    print_uint128_bits(standard[i]);
  }
  printf("check #1\n");

  double *d_re0, *d_re1, *d_im0, *d_im1;
  __uint128_t *d_standard;

  check_cuda_error(cudaMalloc((void **)&d_re0, N * sizeof(double)));
  check_cuda_error(cudaMalloc((void **)&d_re1, N * sizeof(double)));
  check_cuda_error(cudaMalloc((void **)&d_im0, N * sizeof(double)));
  check_cuda_error(cudaMalloc((void **)&d_im1, N * sizeof(double)));

  check_cuda_error(cudaMalloc((void **)&d_standard, N * sizeof(__uint128_t)));

  check_cuda_error(cudaFree(d_re0));
  check_cuda_error(cudaFree(d_re1));
  check_cuda_error(cudaFree(d_im0));
  check_cuda_error(cudaFree(d_im1));

  check_cuda_error(cudaFree(d_standard));
}

#endif // TFHE_RS_BACKENDS_TFHE_CUDA_BACKEND_CUDA_SRC_FFT128_FFT128_CUH_
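As committed, the host wrapper only prints its input and exercises allocation. One plausible completion (hedged: kernel geometry as in the launcher sketch above, and the hi/lo-to-re0/re1 mapping is an assumption) would copy the input over, run the conversion and FFT kernels, and copy the four split arrays back:

```cpp
// Hypothetical completion of the host path (not part of the PR).
template <class params>
void host_forward_f128_sketch(cudaStream_t stream, uint32_t gpu_index,
                              double *re0, double *re1, double *im0,
                              double *im1, __uint128_t const *standard,
                              uint32_t N) {
  check_cuda_error(cudaSetDevice(gpu_index));
  __uint128_t *d_std;
  double *d[4]; // device-side re_hi, re_lo, im_hi, im_lo
  check_cuda_error(cudaMalloc((void **)&d_std, N * sizeof(__uint128_t)));
  for (auto &p : d)
    check_cuda_error(cudaMalloc((void **)&p, (N / 2) * sizeof(double)));

  check_cuda_error(cudaMemcpyAsync(d_std, standard, N * sizeof(__uint128_t),
                                   cudaMemcpyHostToDevice, stream));
  uint32_t threads = params::degree / params::opt;
  batch_convert_u128_to_f128<params>
      <<<1, threads, 0, stream>>>(d[0], d[1], d[2], d[3], d_std);
  size_t smem = 4 * (params::degree / 2) * sizeof(double);
  batch_NSMFFT_128<params, FULLSM><<<1, threads, smem, stream>>>(
      d[0], d[1], d[2], d[3], d[0], d[1], d[2], d[3], nullptr);

  double *host_out[4] = {re0, re1, im0, im1}; // hi/lo mapping assumed
  for (int i = 0; i < 4; ++i)
    check_cuda_error(cudaMemcpyAsync(host_out[i], d[i],
                                     (N / 2) * sizeof(double),
                                     cudaMemcpyDeviceToHost, stream));
  check_cuda_error(cudaStreamSynchronize(stream));
  check_cuda_error(cudaFree(d_std));
  for (auto &p : d)
    check_cuda_error(cudaFree(p));
}
```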
backends/tfhe-cuda-backend/cuda/src/fft128/twiddles.cu (new file, 5476 lines)
File diff suppressed because it is too large.
backends/tfhe-cuda-backend/cuda/src/fft128/twiddles.cuh (new file, 11 lines)
@@ -0,0 +1,11 @@
#ifndef GPU_BOOTSTRAP_128_TWIDDLES_CUH
#define GPU_BOOTSTRAP_128_TWIDDLES_CUH

/*
 * The negacyclic twiddles are stored in device memory to benefit from caching
 */
extern __device__ double neg_twiddles_re_hi[4096];
extern __device__ double neg_twiddles_re_lo[4096];
extern __device__ double neg_twiddles_im_hi[4096];
extern __device__ double neg_twiddles_im_lo[4096];
#endif
@@ -1238,6 +1238,18 @@ extern "C" {
         input_lwe_ciphertext_count: u32,
     );
 }
+extern "C" {
+    pub fn fourier_transform_forward_f128(
+        stream: *mut ffi::c_void,
+        gpu_index: u32,
+        re0: *mut ffi::c_void,
+        re1: *mut ffi::c_void,
+        im0: *mut ffi::c_void,
+        im1: *mut ffi::c_void,
+        standard: *const ffi::c_void,
+        N: u32,
+    );
+}
 extern "C" {
     pub fn cuda_fourier_polynomial_mul(
         stream: *mut ffi::c_void,
@@ -4,5 +4,6 @@
 #include "cuda/include/integer/integer.h"
 #include "cuda/include/keyswitch.h"
 #include "cuda/include/linear_algebra.h"
+#include "cuda/include/pbs/fft.h"
 #include "cuda/include/pbs/programmable_bootstrap.h"
 #include "cuda/include/pbs/programmable_bootstrap_multibit.h"
parse_f128_twiddles.py (new file, 35 lines)
@@ -0,0 +1,35 @@
import sys


def generate_twiddles(file_path):
    """Parse a twiddle dump and emit a C++ array.

    Expected input: a first line of the form `n = <count>`, followed by
    lines of the form `twid_re_hi: <hex>, <float>`.
    """
    try:
        with open(file_path, "r") as file:
            lines = file.readlines()

        # Parse n
        n_line = lines[0].strip()
        n = int(n_line.split('=')[1].strip())

        # Parse twiddle data
        twiddles = []
        for line in lines[1:]:
            if "twid_re_hi" in line:
                parts = line.split(':')[1].strip().split(',')
                hex_val = parts[0].strip()
                float_val = parts[1].strip()
                twiddles.append((hex_val, float_val))

        # Generate C++ code (only the float column is emitted)
        cpp_code = f"double negtwiddles[{n}] = {{\n"
        for _hex_val, float_val in twiddles:
            cpp_code += f"    {float_val},\n"
        cpp_code += "};\n"

        print(cpp_code)
    except Exception as e:
        print(f"Error: {e}")


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python parse_f128_twiddles.py <file_path>")
    else:
        generate_twiddles(sys.argv[1])
@@ -140,6 +140,7 @@ pub fn u128_to_f64(x: u128) -> f64 {
     const B: f64 = (1u128 << 104) as f64;
     const C: f64 = (1u128 << 76) as f64;
     const D: f64 = u128::MAX as f64;

     if x < 1 << 104 {
         let l = f64::from_bits(A.to_bits() | ((x << 12) as u64 >> 12)) - A;
         let h = f64::from_bits(B.to_bits() | (x >> 52) as u64) - B;
@@ -19,6 +19,7 @@ fn test_roundtrip<Scalar: UnsignedTorus>() {
     let mut fourier_im0 = avec![0.0f64; fourier_size].into_boxed_slice();
     let mut fourier_im1 = avec![0.0f64; fourier_size].into_boxed_slice();

+    println!("sizeof_scalar: {:?}", Scalar::BITS);
     for x in poly.as_mut().iter_mut() {
         *x = generator.random_uniform();
     }
tfhe/src/core_crypto/gpu/algorithms/test/fft.rs (new file, 245 lines)
@@ -0,0 +1,245 @@
use super::*;
use crate::core_crypto::commons::test_tools::{modular_distance, new_random_generator};
use crate::core_crypto::commons::utils::izip;
use crate::core_crypto::gpu::{fourier_transform_forward_f128_async, CudaStreams};
use aligned_vec::avec;
use dyn_stack::{GlobalPodBuffer, PodStack, ReborrowMut};

fn test_roundtrip<Scalar: UnsignedTorus>() {
    let mut generator = new_random_generator();
    for size_log in 10..=10 {
        let size = 1_usize << size_log;
        let fourier_size = PolynomialSize(size).to_fourier_polynomial_size().0;

        let fft = Fft128::new(PolynomialSize(size));
        let fft = fft.as_view();

        let mut poly = avec![Scalar::ZERO; size].into_boxed_slice();
        let mut roundtrip = avec![Scalar::ZERO; size].into_boxed_slice();
        let mut fourier_re0 = avec![0.0f64; fourier_size].into_boxed_slice();
        let mut fourier_re1 = avec![0.0f64; fourier_size].into_boxed_slice();
        let mut fourier_im0 = avec![0.0f64; fourier_size].into_boxed_slice();
        let mut fourier_im1 = avec![0.0f64; fourier_size].into_boxed_slice();

        println!("sizeof_scalar: {:?}", Scalar::BITS);
        for x in poly.as_mut().iter_mut() {
            *x = generator.random_uniform();
        }

        let mut mem = GlobalPodBuffer::new(fft.backward_scratch().unwrap());
        let mut stack = PodStack::new(&mut mem);

        fft.forward_as_torus(
            &mut fourier_re0,
            &mut fourier_re1,
            &mut fourier_im0,
            &mut fourier_im1,
            &poly,
        );
        let gpu_index = 0;
        let stream = CudaStreams::new_single_gpu(gpu_index);

        unsafe {
            println!("size: {:?}", size);
            println!("poly.len: {:?}", poly.len());
            println!("rust poly");
            for coefficient in poly.iter() {
                println!(
                    "{:0width$b}",
                    coefficient,
                    width = std::mem::size_of::<Scalar>() * 8
                );
            }
            fourier_transform_forward_f128_async(
                &stream,
                &mut fourier_re0,
                &mut fourier_re1,
                &mut fourier_im0,
                &mut fourier_im1,
                &poly,
                poly.len() as u32,
            );
        }

        fft.backward_as_torus(
            &mut roundtrip,
            &fourier_re0,
            &fourier_re1,
            &fourier_im0,
            &fourier_im1,
            stack.rb_mut(),
        );

        for (expected, actual) in izip!(poly.as_ref().iter(), roundtrip.as_ref().iter()) {
            if Scalar::BITS <= 64 {
                assert_eq!(*expected, *actual);
            } else {
                let abs_diff = modular_distance(*expected, *actual);
                let threshold = Scalar::ONE << (128 - 100);
                assert!(
                    abs_diff < threshold,
                    "abs_diff: {abs_diff}, threshold: {threshold}",
                );
            }
        }
    }
}

// fn test_product<Scalar: UnsignedTorus>() {
//     fn convolution_naive<Scalar: UnsignedTorus>(
//         out: &mut [Scalar],
//         lhs: &[Scalar],
//         rhs: &[Scalar],
//     ) {
//         assert_eq!(out.len(), lhs.len());
//         assert_eq!(out.len(), rhs.len());
//         let n = out.len();
//         let mut full_prod = vec![Scalar::ZERO; 2 * n];
//         for i in 0..n {
//             for j in 0..n {
//                 full_prod[i + j] = full_prod[i + j].wrapping_add(lhs[i].wrapping_mul(rhs[j]));
//             }
//         }
//         for i in 0..n {
//             out[i] = full_prod[i].wrapping_sub(full_prod[i + n]);
//         }
//     }
//
//     let mut generator = new_random_generator();
//     for size_log in 6..=14 {
//         for _ in 0..10 {
//             let size = 1_usize << size_log;
//             let fourier_size = PolynomialSize(size).to_fourier_polynomial_size().0;
//
//             let fft = Fft128::new(PolynomialSize(size));
//             let fft = fft.as_view();
//
//             let mut poly0 = avec![Scalar::ZERO; size].into_boxed_slice();
//             let mut poly1 = avec![Scalar::ZERO; size].into_boxed_slice();
//
//             let mut convolution_from_fft = avec![Scalar::ZERO; size].into_boxed_slice();
//             let mut convolution_from_naive = avec![Scalar::ZERO; size].into_boxed_slice();
//
//             let mut fourier0_re0 = avec![0.0f64; fourier_size].into_boxed_slice();
//             let mut fourier0_re1 = avec![0.0f64; fourier_size].into_boxed_slice();
//             let mut fourier0_im0 = avec![0.0f64; fourier_size].into_boxed_slice();
//             let mut fourier0_im1 = avec![0.0f64; fourier_size].into_boxed_slice();
//
//             let mut fourier1_re0 = avec![0.0f64; fourier_size].into_boxed_slice();
//             let mut fourier1_re1 = avec![0.0f64; fourier_size].into_boxed_slice();
//             let mut fourier1_im0 = avec![0.0f64; fourier_size].into_boxed_slice();
//             let mut fourier1_im1 = avec![0.0f64; fourier_size].into_boxed_slice();
//
//             let integer_magnitude = 16;
//             for (x, y) in izip!(poly0.as_mut().iter_mut(), poly1.as_mut().iter_mut()) {
//                 *x = generator.random_uniform();
//                 *y = generator.random_uniform();
//
//                 *y >>= Scalar::BITS - integer_magnitude;
//             }
//
//             let mut mem = GlobalPodBuffer::new(fft.backward_scratch().unwrap());
//             let mut stack = PodStack::new(&mut mem);
//
//             fft.forward_as_torus(
//                 &mut fourier0_re0,
//                 &mut fourier0_re1,
//                 &mut fourier0_im0,
//                 &mut fourier0_im1,
//                 &poly0,
//             );
//             fft.forward_as_integer(
//                 &mut fourier1_re0,
//                 &mut fourier1_re1,
//                 &mut fourier1_im0,
//                 &mut fourier1_im1,
//                 &poly1,
//             );
//
//             for (f0_re0, f0_re1, f0_im0, f0_im1, f1_re0, f1_re1, f1_im0, f1_im1) in izip!(
//                 &mut *fourier0_re0,
//                 &mut *fourier0_re1,
//                 &mut *fourier0_im0,
//                 &mut *fourier0_im1,
//                 &*fourier1_re0,
//                 &*fourier1_re1,
//                 &*fourier1_im0,
//                 &*fourier1_im1,
//             ) {
//                 let f0_re = f128(*f0_re0, *f0_re1);
//                 let f0_im = f128(*f0_im0, *f0_im1);
//                 let f1_re = f128(*f1_re0, *f1_re1);
//                 let f1_im = f128(*f1_im0, *f1_im1);
//
//                 f128(*f0_re0, *f0_re1) = f0_re * f1_re - f0_im * f1_im;
//                 f128(*f0_im0, *f0_im1) = f0_im * f1_re + f0_re * f1_im;
//             }
//
//             fft.backward_as_torus(
//                 &mut convolution_from_fft,
//                 &fourier0_re0,
//                 &fourier0_re1,
//                 &fourier0_im0,
//                 &fourier0_im1,
//                 stack.rb_mut(),
//             );
//             convolution_naive(
//                 convolution_from_naive.as_mut(),
//                 poly0.as_ref(),
//                 poly1.as_ref(),
//             );
//
//             for (expected, actual) in izip!(
//                 convolution_from_naive.as_ref().iter(),
//                 convolution_from_fft.as_ref().iter()
//             ) {
//                 let threshold = Scalar::ONE
//                     << (Scalar::BITS.saturating_sub(100 - integer_magnitude - size_log));
//                 let abs_diff = modular_distance(*expected, *actual);
//                 assert!(
//                     abs_diff <= threshold,
//                     "abs_diff: {abs_diff}, threshold: {threshold}",
//                 );
//             }
//         }
//     }
// }

// #[test]
// fn test_roundtrip_u32() {
//     test_roundtrip::<u32>();
// }
// #[test]
// fn test_roundtrip_u64() {
//     test_roundtrip::<u64>();
// }

#[test]
fn test_roundtrip_u128() {
    test_roundtrip::<u128>();
}

// fn test_roundtrip_u128<
//     Scalar: UnsignedTorus + Sync + Send + CastFrom<usize> + CastInto<usize>,
// >(
//     params: ClassicTestParams<Scalar>,
// ) {
//     test_roundtrip::<u128>();
// }

// create_gpu_parametrized_test!(test_roundtrip_u128);

// #[test]
// fn test_product_u32() {
//     test_product::<u32>();
// }
//
// #[test]
// fn test_product_u64() {
//     test_product::<u64>();
// }
//
// #[test]
// fn test_product_u128() {
//     test_product::<u128>();
// }
@@ -1,5 +1,6 @@
 use crate::core_crypto::algorithms::test::*;

+mod fft;
 mod glwe_sample_extraction;
 mod lwe_keyswitch;
 mod lwe_linear_algebra;
@@ -630,6 +630,27 @@ pub unsafe fn mult_lwe_ciphertext_vector_cleartext_vector<T: UnsignedInteger>(
     );
 }

+#[allow(clippy::too_many_arguments)]
+pub unsafe fn fourier_transform_forward_f128_async<T: UnsignedInteger>(
+    streams: &CudaStreams,
+    re0: &mut [f64],
+    re1: &mut [f64],
+    im0: &mut [f64],
+    im1: &mut [f64],
+    standard: &[T],
+    fft_size: u32,
+) {
+    fourier_transform_forward_f128(
+        streams.ptr[0],
+        streams.gpu_indexes[0],
+        re0.as_mut_ptr() as *mut c_void,
+        re1.as_mut_ptr() as *mut c_void,
+        im0.as_mut_ptr() as *mut c_void,
+        im1.as_mut_ptr() as *mut c_void,
+        standard.as_ptr() as *const c_void,
+        fft_size,
+    );
+}
 #[derive(Debug)]
 pub struct CudaLweList<T: UnsignedInteger> {
     // Pointer to GPU data