mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 11:35:02 -05:00
chore(cuda): replace casts with cuda intrinsics
This commit is contained in:
@@ -103,9 +103,9 @@ __global__ void device_bootstrap_amortized(
|
||||
GadgetMatrix<Torus, params> gadget(base_log, level_count);
|
||||
|
||||
// Put "b", the body, in [0, 2N[
|
||||
Torus b_hat = rescale_torus_element(
|
||||
block_lwe_array_in[lwe_dimension],
|
||||
2 * params::degree); // 2 * params::log2_degree + 1);
|
||||
Torus b_hat = 0;
|
||||
rescale_torus_element(block_lwe_array_in[lwe_dimension], b_hat,
|
||||
2 * params::degree); // 2 * params::log2_degree + 1);
|
||||
|
||||
divide_by_monomial_negacyclic_inplace<Torus, params::opt,
|
||||
params::degree / params::opt>(
|
||||
@@ -123,9 +123,9 @@ __global__ void device_bootstrap_amortized(
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// Put "a" in [0, 2N[ instead of Zq
|
||||
Torus a_hat = rescale_torus_element(
|
||||
block_lwe_array_in[iteration],
|
||||
2 * params::degree); // 2 * params::log2_degree + 1);
|
||||
Torus a_hat = 0;
|
||||
rescale_torus_element(block_lwe_array_in[iteration], a_hat,
|
||||
2 * params::degree); // 2 * params::log2_degree + 1);
|
||||
|
||||
// Perform ACC * (X^ä - 1)
|
||||
multiply_by_monomial_negacyclic_and_sub_polynomial<
|
||||
@@ -176,8 +176,8 @@ __global__ void device_bootstrap_amortized(
|
||||
|
||||
// Reduce the size of the FFT to be performed by storing
|
||||
// the real-valued polynomial into a complex polynomial
|
||||
real_to_complex_compressed<int16_t, params>(accumulator_mask_decomposed,
|
||||
accumulator_fft);
|
||||
real_to_complex_compressed<params>(accumulator_mask_decomposed,
|
||||
accumulator_fft);
|
||||
|
||||
synchronize_threads_in_block();
|
||||
// Switch to the FFT space
|
||||
@@ -208,8 +208,8 @@ __global__ void device_bootstrap_amortized(
|
||||
|
||||
// Now handle the polynomial multiplication for the body
|
||||
// in the same way
|
||||
real_to_complex_compressed<int16_t, params>(accumulator_body_decomposed,
|
||||
accumulator_fft);
|
||||
real_to_complex_compressed<params>(accumulator_body_decomposed,
|
||||
accumulator_fft);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
NSMFFT_direct<HalfDegree<params>>(accumulator_fft);
|
||||
|
||||
@@ -36,7 +36,7 @@ mul_ggsw_glwe(Torus *accumulator, double2 *fft, int16_t *glwe_decomposed,
|
||||
int iteration, grid_group &grid) {
|
||||
|
||||
// Put the decomposed GLWE sample in the Fourier domain
|
||||
real_to_complex_compressed<int16_t, params>(glwe_decomposed, fft);
|
||||
real_to_complex_compressed<params>(glwe_decomposed, fft);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// Switch to the FFT space
|
||||
@@ -184,8 +184,9 @@ __global__ void device_bootstrap_low_latency(
|
||||
GadgetMatrix<Torus, params> gadget(base_log, level_count);
|
||||
|
||||
// Put "b" in [0, 2N[
|
||||
Torus b_hat = rescale_torus_element(block_lwe_array_in[lwe_dimension],
|
||||
2 * params::degree);
|
||||
Torus b_hat = 0;
|
||||
rescale_torus_element(block_lwe_array_in[lwe_dimension], b_hat,
|
||||
2 * params::degree);
|
||||
|
||||
if (blockIdx.y == 0) {
|
||||
divide_by_monomial_negacyclic_inplace<Torus, params::opt,
|
||||
@@ -201,9 +202,9 @@ __global__ void device_bootstrap_low_latency(
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// Put "a" in [0, 2N[
|
||||
Torus a_hat = rescale_torus_element(
|
||||
block_lwe_array_in[i],
|
||||
2 * params::degree); // 2 * params::log2_degree + 1);
|
||||
Torus a_hat = 0;
|
||||
rescale_torus_element(block_lwe_array_in[i], a_hat,
|
||||
2 * params::degree); // 2 * params::log2_degree + 1);
|
||||
|
||||
// Perform ACC * (X^ä - 1)
|
||||
multiply_by_monomial_negacyclic_and_sub_polynomial<
|
||||
|
||||
@@ -21,30 +21,12 @@
|
||||
#include "utils/memory.cuh"
|
||||
#include "utils/timer.cuh"
|
||||
|
||||
template <typename T, class params>
|
||||
__device__ void fft(double2 *output, T *input) {
|
||||
template <class params> __device__ void fft(double2 *output, int16_t *input) {
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// Reduce the size of the FFT to be performed by storing
|
||||
// the real-valued polynomial into a complex polynomial
|
||||
real_to_complex_compressed<T, params>(input, output);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// Switch to the FFT space
|
||||
NSMFFT_direct<HalfDegree<params>>(output);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
correction_direct_fft_inplace<params>(output);
|
||||
synchronize_threads_in_block();
|
||||
}
|
||||
|
||||
template <typename T, typename ST, class params>
|
||||
__device__ void fft(double2 *output, T *input) {
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// Reduce the size of the FFT to be performed by storing
|
||||
// the real-valued polynomial into a complex polynomial
|
||||
real_to_complex_compressed<T, ST, params>(input, output);
|
||||
real_to_complex_compressed<params>(input, output);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// Switch to the FFT space
|
||||
@@ -154,7 +136,7 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
|
||||
|
||||
// First, perform the polynomial multiplication for the mask
|
||||
synchronize_threads_in_block();
|
||||
fft<int16_t, params>(glwe_fft, glwe_mask_decomposed);
|
||||
fft<params>(glwe_fft, glwe_mask_decomposed);
|
||||
|
||||
// External product and accumulate
|
||||
// Get the piece necessary for the multiplication
|
||||
@@ -175,7 +157,7 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
|
||||
// Now handle the polynomial multiplication for the body
|
||||
// in the same way
|
||||
synchronize_threads_in_block();
|
||||
fft<int16_t, params>(glwe_fft, glwe_body_decomposed);
|
||||
fft<params>(glwe_fft, glwe_body_decomposed);
|
||||
|
||||
// External product and accumulate
|
||||
// Get the piece necessary for the multiplication
|
||||
|
||||
@@ -4,21 +4,26 @@
|
||||
#include "types/int128.cuh"
|
||||
#include <limits>
|
||||
|
||||
template <typename Torus>
|
||||
__device__ inline Torus typecast_double_to_torus(double x) {
|
||||
if constexpr (sizeof(Torus) < 8) {
|
||||
// this simple cast works up to 32 bits, afterwards we must do it manually
|
||||
long long ret = x;
|
||||
return (Torus)ret;
|
||||
} else {
|
||||
int128 nnnn = make_int128_from_float(x);
|
||||
uint64_t lll = nnnn.lo_;
|
||||
return lll;
|
||||
}
|
||||
// nvcc doesn't get it that the if {} else {} above should always return
|
||||
// something, and complains that this function might return nothing, so we
|
||||
// put this useless return here
|
||||
return 0;
|
||||
template <typename T>
|
||||
__device__ inline void typecast_double_to_torus(double x, T &r) {
|
||||
r = T(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void typecast_double_to_torus<uint32_t>(double x,
|
||||
uint32_t &r) {
|
||||
r = __double2uint_rn(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void typecast_double_to_torus<uint64_t>(double x,
|
||||
uint64_t &r) {
|
||||
// The ull intrinsic does not behave in the same way on all architectures and
|
||||
// on some platforms this causes the cmux tree test to fail
|
||||
// Hence the intrinsic is not used here
|
||||
uint128 nnnn = make_uint128_from_float(x);
|
||||
uint64_t lll = nnnn.lo_;
|
||||
r = lll;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -34,10 +39,30 @@ __device__ inline T round_to_closest_multiple(T x, uint32_t base_log,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T rescale_torus_element(T element,
|
||||
uint32_t log_shift) {
|
||||
return round((double)element / (double(std::numeric_limits<T>::max()) + 1.0) *
|
||||
(double)log_shift);
|
||||
__device__ __forceinline__ void rescale_torus_element(T element, T &output,
|
||||
uint32_t log_shift) {
|
||||
output =
|
||||
round((double)element / (double(std::numeric_limits<T>::max()) + 1.0) *
|
||||
(double)log_shift);
|
||||
;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ void
|
||||
rescale_torus_element<uint32_t>(uint32_t element, uint32_t &output,
|
||||
uint32_t log_shift) {
|
||||
output =
|
||||
round(__uint2double_rn(element) /
|
||||
(__uint2double_rn(std::numeric_limits<uint32_t>::max()) + 1.0) *
|
||||
__uint2double_rn(log_shift));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ void
|
||||
rescale_torus_element<uint64_t>(uint64_t element, uint64_t &output,
|
||||
uint32_t log_shift) {
|
||||
output = round(__ull2double_rn(element) /
|
||||
(__ull2double_rn(std::numeric_limits<uint64_t>::max()) + 1.0) *
|
||||
__uint2double_rn(log_shift));
|
||||
}
|
||||
#endif // CNCRT_TORUS_H
|
||||
@@ -6,29 +6,13 @@
|
||||
/*
|
||||
* function compresses decomposed buffer into half size complex buffer for fft
|
||||
*/
|
||||
template <typename T, class params>
|
||||
__device__ void real_to_complex_compressed(T *src, double2 *dst) {
|
||||
template <class params>
|
||||
__device__ void real_to_complex_compressed(int16_t *src, double2 *dst) {
|
||||
int tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
dst[tid].x = (double)src[2 * tid];
|
||||
dst[tid].y = (double)src[2 * tid + 1];
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename ST, class params>
|
||||
__device__ void real_to_complex_compressed(T *src, double2 *dst) {
|
||||
int tid = threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
ST x = src[2 * tid];
|
||||
ST y = src[2 * tid + 1];
|
||||
|
||||
dst[tid].x = x / (double)std::numeric_limits<T>::max();
|
||||
dst[tid].y = y / (double)std::numeric_limits<T>::max();
|
||||
|
||||
dst[tid].x = __int2double_rn(src[2 * tid]);
|
||||
dst[tid].y = __int2double_rn(src[2 * tid + 1]);
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
}
|
||||
@@ -175,14 +159,16 @@ __device__ void add_to_torus(double2 *m_values, Torus *result) {
|
||||
double carry = frac - floor(frac);
|
||||
frac += (carry >= 0.5);
|
||||
|
||||
Torus V1 = typecast_double_to_torus<Torus>(frac);
|
||||
Torus V1 = 0;
|
||||
typecast_double_to_torus<Torus>(frac, V1);
|
||||
|
||||
frac = v2 - floor(v2);
|
||||
frac *= mx;
|
||||
carry = frac - floor(v2);
|
||||
frac += (carry >= 0.5);
|
||||
|
||||
Torus V2 = typecast_double_to_torus<Torus>(frac);
|
||||
Torus V2 = 0;
|
||||
typecast_double_to_torus<Torus>(frac, V2);
|
||||
|
||||
result[tid * 2] += V1;
|
||||
result[tid * 2 + 1] += V2;
|
||||
|
||||
Reference in New Issue
Block a user