diff --git a/src/bootstrap_amortized.cuh b/src/bootstrap_amortized.cuh
index ff2b8777e..2e07fc405 100644
--- a/src/bootstrap_amortized.cuh
+++ b/src/bootstrap_amortized.cuh
@@ -72,14 +72,8 @@ __global__ void device_bootstrap_amortized(
     selected_memory = &device_mem[blockIdx.x * device_memory_size_per_sample];
 
   // For GPU bootstrapping the GLWE dimension is hard-set to 1: there is only
-  // one mask polynomial and 1 body to handle Also, since the decomposed
-  // polynomials take coefficients between -B/2 and B/2 they can be represented
-  // with only 16 bits, assuming the base log does not exceed 2^16
-  int16_t *accumulator_mask_decomposed = (int16_t *)selected_memory;
-  int16_t *accumulator_body_decomposed =
-      (int16_t *)accumulator_mask_decomposed + polynomial_size;
-  Torus *accumulator_mask = (Torus *)accumulator_body_decomposed +
-                            polynomial_size / (sizeof(Torus) / sizeof(int16_t));
+  // one mask polynomial and 1 body to handle.
+  Torus *accumulator_mask = (Torus *)selected_memory;
   Torus *accumulator_body =
       (Torus *)accumulator_mask + (ptrdiff_t)polynomial_size;
   Torus *accumulator_mask_rotated =
@@ -100,8 +94,6 @@ __global__ void device_bootstrap_amortized(
       &lut_vector[lut_vector_indexes[lwe_idx + blockIdx.x] * params::degree *
                   2];
 
-  GadgetMatrix<Torus, params> gadget(base_log, level_count);
-
   // Put "b", the body, in [0, 2N[
   Torus b_hat = 0;
   rescale_torus_element(block_lwe_array_in[lwe_dimension], b_hat,
@@ -159,27 +151,17 @@ __global__ void device_bootstrap_amortized(
     pos += params::degree / params::opt;
   }
 
+  GadgetMatrix<Torus, params> gadget_mask(base_log, level_count,
+                                          accumulator_mask_rotated);
+  GadgetMatrix<Torus, params> gadget_body(base_log, level_count,
+                                          accumulator_body_rotated);
+
   // Now that the rotation is done, decompose the resulting polynomial
   // coefficients so as to multiply each decomposed level with the
   // corresponding part of the bootstrapping key
-  for (int level = 0; level < level_count; level++) {
+  for (int level = level_count - 1; level >= 0; level--) {
-    gadget.decompose_one_level(accumulator_mask_decomposed,
-                               accumulator_mask_rotated, level);
-
-    gadget.decompose_one_level(accumulator_body_decomposed,
-                               accumulator_body_rotated, level);
-
-    synchronize_threads_in_block();
-
-    // First, perform the polynomial multiplication for the mask
-
-    // Reduce the size of the FFT to be performed by storing
-    // the real-valued polynomial into a complex polynomial
-    real_to_complex_compressed(accumulator_mask_decomposed,
-                               accumulator_fft);
-
-    synchronize_threads_in_block();
+    gadget_mask.decompose_and_compress_next(accumulator_fft);
 
     // Switch to the FFT space
     NSMFFT_direct<HalfDegree<params>>(accumulator_fft);
     synchronize_threads_in_block();
@@ -205,12 +187,10 @@ __global__ void device_bootstrap_amortized(
                                   body_res_fft, accumulator_fft,
                                   bsk_body_slice);
     synchronize_threads_in_block();
 
+    gadget_body.decompose_and_compress_next(accumulator_fft);
     // Now handle the polynomial multiplication for the body
     // in the same way
-    real_to_complex_compressed(accumulator_body_decomposed,
-                               accumulator_fft);
-    synchronize_threads_in_block();
 
     NSMFFT_direct<HalfDegree<params>>(accumulator_fft);
     synchronize_threads_in_block();
@@ -306,12 +286,10 @@ __host__ void host_bootstrap_amortized(
   uint32_t gpu_index = 0;
 
-  int SM_FULL = sizeof(Torus) * polynomial_size +   // accumulator mask
-                sizeof(Torus) * polynomial_size +   // accumulator body
-                sizeof(Torus) * polynomial_size +   // accumulator mask rotated
-                sizeof(Torus) * polynomial_size +   // accumulator body rotated
-                sizeof(int16_t) * polynomial_size + // accumulator_dec mask
-                sizeof(int16_t) * polynomial_size + // accumulator_dec_body
+  int SM_FULL = sizeof(Torus) * polynomial_size + // accumulator mask
+                sizeof(Torus) * polynomial_size + // accumulator body
+                sizeof(Torus) * polynomial_size + // accumulator mask rotated
+                sizeof(Torus) * polynomial_size + // accumulator body rotated
                 sizeof(double2) * polynomial_size / 2 + // accumulator fft mask
                 sizeof(double2) * polynomial_size / 2 + // accumulator fft body
                 sizeof(double2) * polynomial_size / 2;  // calculate buffer fft
diff --git a/src/bootstrap_low_latency.cuh b/src/bootstrap_low_latency.cuh
index 7bd73fa8f..0a1d046bf 100644
--- a/src/bootstrap_low_latency.cuh
+++ b/src/bootstrap_low_latency.cuh
@@ -30,14 +30,10 @@ namespace cg = cooperative_groups;
 
 template <typename Torus, class params>
 __device__ void
-mul_ggsw_glwe(Torus *accumulator, double2 *fft, int16_t *glwe_decomposed,
-              double2 *mask_join_buffer, double2 *body_join_buffer,
-              double2 *bootstrapping_key, int polynomial_size, int level_count,
-              int iteration, grid_group &grid) {
-
-  // Put the decomposed GLWE sample in the Fourier domain
-  real_to_complex_compressed(glwe_decomposed, fft);
-  synchronize_threads_in_block();
+mul_ggsw_glwe(Torus *accumulator, double2 *fft, double2 *mask_join_buffer,
+              double2 *body_join_buffer, double2 *bootstrapping_key,
+              int polynomial_size, int level_count, int iteration,
+              grid_group &grid) {
 
   // Switch to the FFT space
   NSMFFT_direct<HalfDegree<params>>(fft);
@@ -157,16 +153,13 @@ __global__ void device_bootstrap_low_latency(
 
   char *selected_memory = sharedmem;
 
-  int16_t *accumulator_decomposed = (int16_t *)selected_memory;
-  Torus *accumulator = (Torus *)accumulator_decomposed +
-                       polynomial_size / (sizeof(Torus) / sizeof(int16_t));
+  Torus *accumulator = (Torus *)selected_memory;
+  Torus *accumulator_rotated =
+      (Torus *)accumulator + (ptrdiff_t)polynomial_size;
   double2 *accumulator_fft =
-      (double2 *)accumulator +
+      (double2 *)accumulator_rotated +
       polynomial_size / (sizeof(double2) / sizeof(Torus));
 
-  // Reuse memory from accumulator_fft for accumulator_rotated
-  Torus *accumulator_rotated = (Torus *)accumulator_fft;
-
   // The third dimension of the block is used to determine on which ciphertext
   // this block is operating, in the case of batch bootstraps
   auto block_lwe_array_in = &lwe_array_in[blockIdx.z * (lwe_dimension + 1)];
@@ -181,7 +174,6 @@ __global__ void device_bootstrap_low_latency(
   // Since the space is L1 cache is small, we use the same memory location for
   // the rotated accumulator and the fft accumulator, since we know that the
   // rotated array is not in use anymore by the time we perform the fft
-  GadgetMatrix<Torus, params> gadget(base_log, level_count);
 
   // Put "b" in [0, 2N[
   Torus b_hat = 0;
@@ -217,11 +209,14 @@ __global__ void device_bootstrap_low_latency(
                                params::degree / params::opt>(
       accumulator_rotated, base_log, level_count);
 
+  synchronize_threads_in_block();
+
   // Decompose the accumulator. Each block gets one level of the
   // decomposition, for the mask and the body (so block 0 will have the
   // accumulator decomposed at level 0, 1 at 1, etc.)
-  gadget.decompose_one_level(accumulator_decomposed, accumulator_rotated,
-                             blockIdx.x);
+  GadgetMatrix<Torus, params> gadget_acc(base_log, level_count,
+                                         accumulator_rotated);
+  gadget_acc.decompose_and_compress_level(accumulator_fft, blockIdx.x);
 
   // We are using the same memory space for accumulator_fft and
   // accumulator_rotated, so we need to synchronize here to make sure they
@@ -229,9 +224,9 @@
   synchronize_threads_in_block();
   // Perform G^-1(ACC) * GGSW -> GLWE
   mul_ggsw_glwe<Torus, params>(accumulator, accumulator_fft,
-                               accumulator_decomposed, block_mask_join_buffer,
-                               block_body_join_buffer, bootstrapping_key,
-                               polynomial_size, level_count, i, grid);
+                               block_mask_join_buffer, block_body_join_buffer,
+                               bootstrapping_key, polynomial_size,
+                               level_count, i, grid);
   }
 
   auto block_lwe_array_out = &lwe_array_out[blockIdx.z * (polynomial_size + 1)];
@@ -270,8 +265,8 @@ host_bootstrap_low_latency(void *v_stream, Torus *lwe_array_out,
   double2 *body_buffer_fft =
       (double2 *)cuda_malloc_async(buffer_size_per_gpu, *stream, gpu_index);
 
-  int bytes_needed = sizeof(int16_t) * polynomial_size + // accumulator_decomp
-                     sizeof(Torus) * polynomial_size +   // accumulator
+  int bytes_needed = sizeof(Torus) * polynomial_size +    // accumulator_rotated
+                     sizeof(Torus) * polynomial_size +    // accumulator
                      sizeof(double2) * polynomial_size / 2; // accumulator fft
 
   int thds = polynomial_size / params::opt;
diff --git a/src/bootstrap_wop.cuh b/src/bootstrap_wop.cuh
index 30a887134..e3c791592 100644
--- a/src/bootstrap_wop.cuh
+++ b/src/bootstrap_wop.cuh
@@ -21,14 +21,7 @@
 #include "utils/memory.cuh"
 #include "utils/timer.cuh"
 
-template <class params> __device__ void fft(double2 *output, int16_t *input) {
-  synchronize_threads_in_block();
-
-  // Reduce the size of the FFT to be performed by storing
-  // the real-valued polynomial into a complex polynomial
-  real_to_complex_compressed(input, output);
-  synchronize_threads_in_block();
-
+template <class params> __device__ void fft(double2 *output) {
   // Switch to the FFT space
   NSMFFT_direct<HalfDegree<params>>(output);
   synchronize_threads_in_block();
@@ -81,19 +74,14 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
   Torus *glwe_sub_mask = (Torus *)selected_memory;
   Torus *glwe_sub_body = (Torus *)glwe_sub_mask + (ptrdiff_t)polynomial_size;
 
-  int16_t *glwe_mask_decomposed = (int16_t *)(glwe_sub_body + polynomial_size);
-  int16_t *glwe_body_decomposed =
-      (int16_t *)glwe_mask_decomposed + (ptrdiff_t)polynomial_size;
-
-  double2 *mask_res_fft = (double2 *)(glwe_body_decomposed + polynomial_size);
+  double2 *mask_res_fft = (double2 *)glwe_sub_body +
+                          polynomial_size / (sizeof(double2) / sizeof(Torus));
   double2 *body_res_fft =
       (double2 *)mask_res_fft + (ptrdiff_t)polynomial_size / 2;
   double2 *glwe_fft =
       (double2 *)body_res_fft + (ptrdiff_t)(polynomial_size / 2);
 
-  GadgetMatrix<Torus, params> gadget(base_log, level_count);
-
   /////////////////////////////////////
   // glwe2-glwe1
@@ -125,18 +113,18 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
     pos += params::degree / params::opt;
   }
 
+  GadgetMatrix<Torus, params> gadget_mask(base_log, level_count, glwe_sub_mask);
+  GadgetMatrix<Torus, params> gadget_body(base_log, level_count, glwe_sub_body);
   // Subtract each glwe operand, decompose the resulting
   // polynomial coefficients to multiply each decomposed level
   // with the corresponding part of the LUT
-  for (int level = 0; level < level_count; level++) {
+  for (int level = level_count - 1; level >= 0; level--) {
     // Decomposition
-    gadget.decompose_one_level(glwe_mask_decomposed, glwe_sub_mask, level);
-    gadget.decompose_one_level(glwe_body_decomposed, glwe_sub_body, level);
+    gadget_mask.decompose_and_compress_next(glwe_fft);
 
     // First, perform the polynomial multiplication for the mask
-    synchronize_threads_in_block();
-    fft<params>(glwe_fft, glwe_mask_decomposed);
+    fft<params>(glwe_fft);
 
     // External product and accumulate
     // Get the piece necessary for the multiplication
@@ -157,7 +145,9 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
     // Now handle the polynomial multiplication for the body
     // in the same way
     synchronize_threads_in_block();
-    fft<params>(glwe_fft, glwe_body_decomposed);
+
+    gadget_body.decompose_and_compress_next(glwe_fft);
+    fft<params>(glwe_fft);
 
     // External product and accumulate
     // Get the piece necessary for the multiplication
@@ -272,8 +262,6 @@ void host_cmux_tree(void *v_stream, Torus *glwe_array_out, Torus *ggsw_in,
   int memory_needed_per_block =
       sizeof(Torus) * polynomial_size +       // glwe_sub_mask
       sizeof(Torus) * polynomial_size +       // glwe_sub_body
-      sizeof(int16_t) * polynomial_size +     // glwe_mask_decomposed
-      sizeof(int16_t) * polynomial_size +     // glwe_body_decomposed
      sizeof(double2) * polynomial_size / 2 + // mask_res_fft
      sizeof(double2) * polynomial_size / 2 + // body_res_fft
      sizeof(double2) * polynomial_size / 2;  // glwe_fft
@@ -331,11 +319,11 @@ void host_cmux_tree(void *v_stream, Torus *glwe_array_out, Torus *ggsw_in,
     // walks horizontally through the leafs
     if (max_shared_memory < memory_needed_per_block)
       device_batch_cmux
-          <<>>(
-              output, input, d_ggsw_fft_in, d_mem, memory_needed_per_block,
-              glwe_dimension, // k
-              polynomial_size, base_log, level_count,
-              layer_idx // r
+          <<>>(output, input, d_ggsw_fft_in, d_mem,
+               memory_needed_per_block,
+               glwe_dimension, // k
+               polynomial_size, base_log, level_count,
+               layer_idx // r
           );
     else
       device_batch_cmux
@@ -640,8 +628,6 @@ void host_blind_rotate_and_sample_extraction(
       sizeof(Torus) * polynomial_size +       // accumulator_c1 body
      sizeof(Torus) * polynomial_size +       // glwe_sub_mask
      sizeof(Torus) * polynomial_size +       // glwe_sub_body
-      sizeof(int16_t) * polynomial_size +     // glwe_mask_decomposed
-      sizeof(int16_t) * polynomial_size +     // glwe_body_decomposed
      sizeof(double2) * polynomial_size / 2 + // mask_res_fft
      sizeof(double2) * polynomial_size / 2 + // body_res_fft
      sizeof(double2) * polynomial_size / 2;  // glwe_fft
diff --git a/src/crypto/gadget.cuh b/src/crypto/gadget.cuh
index 960fe962e..a375acba0 100644
--- a/src/crypto/gadget.cuh
+++ b/src/crypto/gadget.cuh
@@ -12,51 +12,51 @@ private:
   uint32_t mask;
   uint32_t halfbg;
   T offset;
+  int current_level;
+  T mask_mod_b;
+  T *state;
 
 public:
-  __device__ GadgetMatrix(uint32_t base_log, uint32_t level_count)
-      : base_log(base_log), level_count(level_count) {
-    uint32_t bg = 1 << base_log;
-    this->halfbg = bg / 2;
-    this->mask = bg - 1;
-    T temp = 0;
-    for (int i = 0; i < this->level_count; i++) {
-      temp += 1ULL << (sizeof(T) * 8 - (i + 1) * this->base_log);
-    }
-    this->offset = temp * this->halfbg;
-  }
+  __device__ GadgetMatrix(uint32_t base_log, uint32_t level_count, T *state)
+      : base_log(base_log), level_count(level_count), state(state) {
 
-  template <typename V>
-  __device__ void decompose_one_level(Polynomial<V, params> &result,
-                                      Polynomial<T, params> &polynomial,
-                                      uint32_t level) {
+    mask_mod_b = (1ll << base_log) - 1ll;
+    current_level = level_count;
     int tid = threadIdx.x;
     for (int i = 0; i < params::opt; i++) {
-      T s = polynomial.coefficients[tid] + this->offset;
-      uint32_t decal = (sizeof(T) * 8 - (level + 1) * this->base_log);
-      T temp1 = (s >> decal) & this->mask;
-      result.coefficients[tid] = (V)(temp1 - this->halfbg);
-      tid += params::degree / params::opt;
-    }
-  }
-  template <typename V, typename U>
-  __device__ void decompose_one_level(V *result, U *polynomial,
-                                      uint32_t level) {
-    int tid = threadIdx.x;
-    for (int i = 0; i < params::opt; i++) {
-      T s = polynomial[tid] + this->offset;
-      uint32_t decal = (sizeof(T) * 8 - (level + 1) * this->base_log);
-      T temp1 = (s >> decal) & this->mask;
-      result[tid] = (V)(temp1 - this->halfbg);
+      state[tid] >>= (sizeof(T) * 8 - base_log * level_count);
       tid += params::degree / params::opt;
     }
+    synchronize_threads_in_block();
   }
-  __device__ T decompose_one_level_single(T element, uint32_t level) {
-    T s = element + this->offset;
-    uint32_t decal = (sizeof(T) * 8 - (level + 1) * this->base_log);
-    T temp1 = (s >> decal) & this->mask;
-    return (T)(temp1 - this->halfbg);
+  __device__ void decompose_and_compress_next(double2 *result) {
+    int tid = threadIdx.x;
+    current_level -= 1;
+    for (int i = 0; i < params::opt / 2; i++) {
+      T res_re = state[tid * 2] & mask_mod_b;
+      T res_im = state[tid * 2 + 1] & mask_mod_b;
+      state[tid * 2] >>= base_log;
+      state[tid * 2 + 1] >>= base_log;
+      T carry_re = ((res_re - 1ll) | state[tid * 2]) & res_re;
+      T carry_im = ((res_im - 1ll) | state[tid * 2 + 1]) & res_im;
+      carry_re >>= (base_log - 1);
+      carry_im >>= (base_log - 1);
+      state[tid * 2] += carry_re;
+      state[tid * 2 + 1] += carry_im;
+      res_re -= carry_re << base_log;
+      res_im -= carry_im << base_log;
+
+      result[tid].x = (int32_t)res_re;
+      result[tid].y = (int32_t)res_im;
+
+      tid += params::degree / params::opt;
+    }
+    synchronize_threads_in_block();
+  }
+  __device__ void decompose_and_compress_level(double2 *result, int level) {
+    for (int i = 0; i < level_count - level; i++)
+      decompose_and_compress_next(result);
   }
 };
@@ -89,4 +89,15 @@ public:
   }
 };
 
+template <typename Torus>
+__device__ Torus decompose_one(Torus &state, Torus mask_mod_b, int base_log) {
+  Torus res = state & mask_mod_b;
+  state >>= base_log;
+  Torus carry = ((res - 1ll) | state) & res;
+  carry >>= base_log - 1;
+  state += carry;
+  res -= carry << base_log;
+  return res;
+}
+
 #endif // CNCRT_CRPYTO_H
diff --git a/src/keyswitch.cuh b/src/keyswitch.cuh
index 6c1afcc78..efc95f7a5 100644
--- a/src/keyswitch.cuh
+++ b/src/keyswitch.cuh
@@ -17,17 +17,6 @@ __device__ Torus *get_ith_block(Torus *ksk, int i, int level,
   return ptr;
 }
 
-template <typename Torus>
-__device__ Torus decompose_one(Torus &state, Torus mod_b_mask, int base_log) {
-  Torus res = state & mod_b_mask;
-  state >>= base_log;
-  Torus carry = ((res - 1ll) | state) & res;
-  carry >>= base_log - 1;
-  state += carry;
-  res -= carry << base_log;
-  return res;
-}
-
 /*
  * keyswitch kernel
  * Each thread handles a piece of the following equation:
@@ -85,12 +74,12 @@ __global__ void keyswitch(Torus *lwe_array_out, Torus *lwe_array_in, Torus *ksk,
         round_to_closest_multiple(block_lwe_array_in[i], base_log, level_count);
 
     Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);
-    Torus mod_b_mask = (1ll << base_log) - 1ll;
+    Torus mask_mod_b = (1ll << base_log) - 1ll;
 
    for (int j = 0; j < level_count; j++) {
      auto ksk_block = get_ith_block(ksk, i, level_count - j - 1,
                                     lwe_dimension_out, level_count);
-      Torus decomposed = decompose_one(state, mod_b_mask, base_log);
+      Torus decomposed = decompose_one(state, mask_mod_b, base_log);
      for (int k = 0; k < lwe_part_per_thd; k++) {
        int idx = tid + k * blockDim.x;
        local_lwe_array_out[idx] -= (Torus)ksk_block[idx] * decomposed;
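Note on the decomposition above: decompose_one (moved into gadget.cuh) and GadgetMatrix::decompose_and_compress_next both extract balanced base-B digits, i.e. values in [-B/2, B/2], from a running state, propagating a carry into the next digit instead of adding the precomputed offset that the old decompose_one_level used. The host-side C++ sketch below is not part of the diff; it only illustrates that digit extraction. The helper name decompose_one_host and the sample parameters (32-bit torus, base_log = 4, level_count = 3) are illustrative assumptions.

#include <cstdint>
#include <cstdio>

// Host-side sketch (not part of the diff) of the balanced digit extraction
// performed by decompose_one and GadgetMatrix::decompose_and_compress_next.
// Assumes a 32-bit torus; decompose_one_host and the sample parameters below
// are illustrative only.
static uint32_t decompose_one_host(uint32_t &state, uint32_t mask_mod_b,
                                   int base_log) {
  uint32_t res = state & mask_mod_b; // low base_log bits of the state
  state >>= base_log;                // expose the next digit
  // carry = 1 iff res > B/2, or res == B/2 and the next digit is >= B/2
  uint32_t carry = ((res - 1u) | state) & res;
  carry >>= base_log - 1;
  state += carry;           // propagate the carry into the next digit
  res -= carry << base_log; // digit now lies in [-B/2, B/2] (two's complement)
  return res;
}

int main() {
  const int base_log = 4, level_count = 3; // B = 16, 3 levels
  const uint32_t mask_mod_b = (1u << base_log) - 1u;

  uint32_t a = 0x9ABCDE12u; // arbitrary 32-bit torus element
  // Keep only the base_log * level_count most significant bits, as the
  // keyswitch kernel does after round_to_closest_multiple.
  uint32_t state = a >> (32 - base_log * level_count);
  const uint32_t rounded = state;

  int64_t acc = 0;
  for (int j = 0; j < level_count; j++) {
    // Digit j pairs with the ksk block at level level_count - j - 1,
    // i.e. the least significant level comes out first.
    int32_t digit = (int32_t)decompose_one_host(state, mask_mod_b, base_log);
    printf("level %d digit: %d\n", level_count - j - 1, digit);
    acc += (int64_t)digit << (base_log * j); // recombine: sum of d_j * B^j
  }
  // The leftover state is the carry out of the most significant digit.
  acc += (int64_t)state << (base_log * level_count);
  printf("rounded = %u, recombined = %lld\n", rounded, (long long)acc);
  return 0;
}

Because the digits are produced on the fly and written straight into the double2 FFT buffer, the int16_t scratch polynomials disappear, which is what the smaller SM_FULL, bytes_needed, and memory_needed_per_block computations above reflect.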