From 01ea1cf2f22324c5de5ba092f16779682ebd76b6 Mon Sep 17 00:00:00 2001 From: Beka Barbakadze Date: Sun, 25 Sep 2022 23:03:58 +0400 Subject: [PATCH] feat(cuda): add extract bits feature in concrete-cuda - also, update decomposition algorithm for concrete-cuda keyswitch --- include/bootstrap.h | 49 ++++++++++ src/bootstrap_wop.cu | 125 ++++++++++++++++++++++++- src/bootstrap_wop.cuh | 207 +++++++++++++++++++++++++++++++++++++++++- src/crypto/gadget.cuh | 1 + src/keyswitch.cuh | 22 ++++- 5 files changed, 399 insertions(+), 5 deletions(-) diff --git a/include/bootstrap.h b/include/bootstrap.h index c17c196b6..d56cd9d50 100644 --- a/include/bootstrap.h +++ b/include/bootstrap.h @@ -102,6 +102,55 @@ void cuda_cmux_tree_64( uint32_t l_gadget, uint32_t r, uint32_t max_shared_memory); + + + +void cuda_extract_bits_32( + void *v_stream, + void *list_lwe_out, + void *lwe_in, + void *lwe_in_buffer, + void *lwe_in_shifted_buffer, + void *lwe_out_ks_buffer, + void *lwe_out_pbs_buffer, + void *lut_pbs, + void *lut_vector_indexes, + void *ksk, + void *fourier_bsk, + uint32_t number_of_bits, + uint32_t delta_log, + uint32_t lwe_dimension_before, + uint32_t lwe_dimension_after, + uint32_t base_log_bsk, + uint32_t l_gadget_bsk, + uint32_t base_log_ksk, + uint32_t l_gadget_ksk, + uint32_t number_of_samples); + + +void cuda_extract_bits_64( + void *v_stream, + void *list_lwe_out, + void *lwe_in, + void *lwe_in_buffer, + void *lwe_in_shifted_buffer, + void *lwe_out_ks_buffer, + void *lwe_out_pbs_buffer, + void *lut_pbs, + void *lut_vector_indexes, + void *ksk, + void *fourier_bsk, + uint32_t number_of_bits, + uint32_t delta_log, + uint32_t lwe_dimension_before, + uint32_t lwe_dimension_after, + uint32_t base_log_bsk, + uint32_t l_gadget_bsk, + uint32_t base_log_ksk, + uint32_t l_gadget_ksk, + uint32_t number_of_samples); + + }; #ifdef __CUDACC__ diff --git a/src/bootstrap_wop.cu b/src/bootstrap_wop.cu index a1da727e6..05453f56e 100644 --- a/src/bootstrap_wop.cu +++ 
// NOTE(review): this region of the patch was mangled during extraction —
// newlines were collapsed and every <...> span (template headers, template
// instantiations) was stripped.  The code below is the reconstructed,
// properly formatted content.  The instantiation parameter class is assumed
// to be `Degree<N>` as used elsewhere in concrete-cuda — confirm against the
// original commit before merging.

/// Bit-extraction entry point for 32-bit torus ciphertexts.
///
/// Dispatches on `lwe_dimension_before` (the input polynomial size: 512,
/// 1024 or 2048) to the matching compile-time instantiation of
/// `host_extract_bits`.  All buffer arguments are device pointers;
/// `v_stream` is a `cudaStream_t *`.
void cuda_extract_bits_32(
    void *v_stream, void *list_lwe_out, void *lwe_in, void *lwe_in_buffer,
    void *lwe_in_shifted_buffer, void *lwe_out_ks_buffer,
    void *lwe_out_pbs_buffer, void *lut_pbs, void *lut_vector_indexes,
    void *ksk, void *fourier_bsk, uint32_t number_of_bits, uint32_t delta_log,
    uint32_t lwe_dimension_before, uint32_t lwe_dimension_after,
    uint32_t base_log_bsk, uint32_t l_gadget_bsk, uint32_t base_log_ksk,
    uint32_t l_gadget_ksk, uint32_t number_of_samples) {
  switch (lwe_dimension_before) {
  case 512:
    host_extract_bits<uint32_t, Degree<512>>(
        v_stream, (uint32_t *)list_lwe_out, (uint32_t *)lwe_in,
        (uint32_t *)lwe_in_buffer, (uint32_t *)lwe_in_shifted_buffer,
        (uint32_t *)lwe_out_ks_buffer, (uint32_t *)lwe_out_pbs_buffer,
        (uint32_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint32_t *)ksk,
        (double2 *)fourier_bsk, number_of_bits, delta_log,
        lwe_dimension_before, lwe_dimension_after, base_log_bsk, l_gadget_bsk,
        base_log_ksk, l_gadget_ksk, number_of_samples);
    break;
  case 1024:
    host_extract_bits<uint32_t, Degree<1024>>(
        v_stream, (uint32_t *)list_lwe_out, (uint32_t *)lwe_in,
        (uint32_t *)lwe_in_buffer, (uint32_t *)lwe_in_shifted_buffer,
        (uint32_t *)lwe_out_ks_buffer, (uint32_t *)lwe_out_pbs_buffer,
        (uint32_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint32_t *)ksk,
        (double2 *)fourier_bsk, number_of_bits, delta_log,
        lwe_dimension_before, lwe_dimension_after, base_log_bsk, l_gadget_bsk,
        base_log_ksk, l_gadget_ksk, number_of_samples);
    break;
  case 2048:
    host_extract_bits<uint32_t, Degree<2048>>(
        v_stream, (uint32_t *)list_lwe_out, (uint32_t *)lwe_in,
        (uint32_t *)lwe_in_buffer, (uint32_t *)lwe_in_shifted_buffer,
        (uint32_t *)lwe_out_ks_buffer, (uint32_t *)lwe_out_pbs_buffer,
        (uint32_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint32_t *)ksk,
        (double2 *)fourier_bsk, number_of_bits, delta_log,
        lwe_dimension_before, lwe_dimension_after, base_log_bsk, l_gadget_bsk,
        base_log_ksk, l_gadget_ksk, number_of_samples);
    break;
  default:
    // NOTE(review): unsupported dimensions are silently ignored; consider
    // surfacing an error to the caller instead.
    break;
  }
}

/// Bit-extraction entry point for 64-bit torus ciphertexts.
/// Identical dispatch to cuda_extract_bits_32, instantiated on uint64_t.
/// `lut_vector_indexes` stays a uint32_t index table in both variants.
void cuda_extract_bits_64(
    void *v_stream, void *list_lwe_out, void *lwe_in, void *lwe_in_buffer,
    void *lwe_in_shifted_buffer, void *lwe_out_ks_buffer,
    void *lwe_out_pbs_buffer, void *lut_pbs, void *lut_vector_indexes,
    void *ksk, void *fourier_bsk, uint32_t number_of_bits, uint32_t delta_log,
    uint32_t lwe_dimension_before, uint32_t lwe_dimension_after,
    uint32_t base_log_bsk, uint32_t l_gadget_bsk, uint32_t base_log_ksk,
    uint32_t l_gadget_ksk, uint32_t number_of_samples) {
  switch (lwe_dimension_before) {
  case 512:
    host_extract_bits<uint64_t, Degree<512>>(
        v_stream, (uint64_t *)list_lwe_out, (uint64_t *)lwe_in,
        (uint64_t *)lwe_in_buffer, (uint64_t *)lwe_in_shifted_buffer,
        (uint64_t *)lwe_out_ks_buffer, (uint64_t *)lwe_out_pbs_buffer,
        (uint64_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint64_t *)ksk,
        (double2 *)fourier_bsk, number_of_bits, delta_log,
        lwe_dimension_before, lwe_dimension_after, base_log_bsk, l_gadget_bsk,
        base_log_ksk, l_gadget_ksk, number_of_samples);
    break;
  case 1024:
    host_extract_bits<uint64_t, Degree<1024>>(
        v_stream, (uint64_t *)list_lwe_out, (uint64_t *)lwe_in,
        (uint64_t *)lwe_in_buffer, (uint64_t *)lwe_in_shifted_buffer,
        (uint64_t *)lwe_out_ks_buffer, (uint64_t *)lwe_out_pbs_buffer,
        (uint64_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint64_t *)ksk,
        (double2 *)fourier_bsk, number_of_bits, delta_log,
        lwe_dimension_before, lwe_dimension_after, base_log_bsk, l_gadget_bsk,
        base_log_ksk, l_gadget_ksk, number_of_samples);
    break;
  case 2048:
    host_extract_bits<uint64_t, Degree<2048>>(
        v_stream, (uint64_t *)list_lwe_out, (uint64_t *)lwe_in,
        (uint64_t *)lwe_in_buffer, (uint64_t *)lwe_in_shifted_buffer,
        (uint64_t *)lwe_out_ks_buffer, (uint64_t *)lwe_out_pbs_buffer,
        (uint64_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint64_t *)ksk,
        (double2 *)fourier_bsk, number_of_bits, delta_log,
        lwe_dimension_before, lwe_dimension_after, base_log_bsk, l_gadget_bsk,
        base_log_ksk, l_gadget_ksk, number_of_samples);
    break;
  default:
    // NOTE(review): unsupported dimensions are silently ignored.
    break;
  }
}

// ---- src/bootstrap_wop.cuh ----
// Include changes carried by the original patch: add keyswitch.cuh and
// bootstrap_low_latency.cuh; drop the direct gadget.cuh include (gadget.cuh
// gains `#pragma once` in the same patch); keep ggsw.cuh.

// Only valid for the big-LWE layout of the ks+bs path.
// One thread block per input ciphertext (blockIdx.x selects it); each
// ciphertext holds params::degree + 1 Torus elements, params::opt elements
// handled per thread.  Writes two outputs in one pass:
//   dst_copy  = src                (the running "state" LWE)
//   dst_shift = src * value       (scalar-shifted LWE fed to the keyswitch)
template <typename Torus, class params>
__global__ void copy_and_shift_lwe(Torus *dst_copy, Torus *dst_shift,
                                   Torus *src, Torus value) {
  int block_id = blockIdx.x;
  int tid = threadIdx.x;
  Torus *cur_dst_copy = &dst_copy[block_id * (params::degree + 1)];
  Torus *cur_dst_shift = &dst_shift[block_id * (params::degree + 1)];
  Torus *cur_src = &src[block_id * (params::degree + 1)];

#pragma unroll
  for (int i = 0; i < params::opt; i++) {
    cur_dst_copy[tid] = cur_src[tid];
    cur_dst_shift[tid] = cur_src[tid] * value;
    tid += params::degree / params::opt;
  }

  // The last thread of the stride also handles the body coefficient
  // (element params::degree, the +1 past the mask).
  if (threadIdx.x == params::degree / params::opt - 1) {
    cur_dst_copy[params::degree] = cur_src[params::degree];
    cur_dst_shift[params::degree] = cur_src[params::degree] * value;
  }
}
// Copies one small-LWE ciphertext into its slot of the per-sample output
// list; used in the ks+bs path where the LWE length need not be a power of
// two.  Grid: one block per sample; works for any thread count (the caller
// launches 256 threads).
// Layout of dst: number_of_bits ciphertexts of small_lwe_size elements per
// sample; `lwe_id` selects the destination slot.
template <typename Torus>
__global__ void copy_small_lwe(Torus *dst, Torus *src, uint32_t small_lwe_size,
                               uint32_t number_of_bits, uint32_t lwe_id) {
  size_t block_id = blockIdx.x;
  size_t threads_per_block = blockDim.x;
  size_t full_chunks = small_lwe_size / threads_per_block;
  // FIX: the original computed the remainder as
  // `small_lwe_size & (threads_per_block - 1)`, which is only correct when
  // the block size is a power of two; `%` is correct for any launch
  // configuration (identical result for the 256-thread launch used here).
  size_t rem = small_lwe_size % threads_per_block;

  Torus *cur_lwe_list = &dst[block_id * small_lwe_size * number_of_bits];
  Torus *cur_dst = &cur_lwe_list[lwe_id * small_lwe_size];
  Torus *cur_src = &src[block_id * small_lwe_size];

  size_t tid = threadIdx.x;
  for (size_t i = 0; i < full_chunks; i++) {
    cur_dst[tid] = cur_src[tid];
    tid += threads_per_block;
  }

  // Tail elements when small_lwe_size is not a multiple of the block size.
  if (threadIdx.x < rem)
    cur_dst[tid] = cur_src[tid];
}

// Adds `value` to the body coefficient (index lwe_dimension) of the LWE
// selected by blockIdx.x.  Only used in extract-bits; launched with a single
// thread per block.
// NOTE: check whether folding this into copy_small_lwe or the LUT fill
// kernel is faster (kept from the original author's note).
template <typename Torus>
__global__ void add_to_body(Torus *lwe, size_t lwe_dimension, Torus value) {
  lwe[blockIdx.x * (lwe_dimension + 1) + lwe_dimension] += value;
}

// Fills only the body polynomial of the PBS test vector with `value`
// (equivalent to a trivial encryption: the mask half is assumed to already
// be zero).  `value` is -alpha with alpha = delta * 2^{bit_idx-1}.
// Launched with params::degree / params::opt threads, params::opt
// coefficients per thread.
template <typename Torus, class params>
__global__ void fill_lut_body_for_current_bit(Torus *lut, Torus value) {
  Torus *body = &lut[params::degree];
  size_t tid = threadIdx.x;
#pragma unroll
  for (int i = 0; i < params::opt; i++) {
    body[tid] = value;
    tid += params::degree / params::opt;
  }
}
// Post-PBS fixup kernel for one extract-bits iteration:
//  * adds alpha (= delta * 2^{bit_idx-1}) back so the PBS output encrypts 0
//    or 1 exactly,
//  * subtracts the extracted bit from the state LWE (zeroing that bit),
//  * pre-multiplies the state into the shifted buffer for the next round.
// The multiplier uses (ciphertext_n_bits - delta_log - bit_idx - 2), not
// ... - 1, so the next bit lands on the padding bit.
template <typename Torus, class params>
__global__ void add_sub_and_mul_lwe(Torus *shifted_lwe, Torus *state_lwe,
                                    Torus *pbs_lwe_out, Torus add_value,
                                    Torus mul_value) {
  size_t tid = threadIdx.x;
  size_t block_id = blockIdx.x;
  Torus *cur_shifted = &shifted_lwe[block_id * (params::degree + 1)];
  Torus *cur_state = &state_lwe[block_id * (params::degree + 1)];
  Torus *cur_pbs_out = &pbs_lwe_out[block_id * (params::degree + 1)];

#pragma unroll
  for (int i = 0; i < params::opt; i++) {
    cur_shifted[tid] = cur_state[tid] -= cur_pbs_out[tid];
    cur_shifted[tid] *= mul_value;
    tid += params::degree / params::opt;
  }

  // Body coefficient: also fold in add_value (see header comment).
  if (threadIdx.x == params::degree / params::opt - 1) {
    cur_shifted[params::degree] = cur_state[params::degree] -=
        (cur_pbs_out[params::degree] + add_value);
    cur_shifted[params::degree] *= mul_value;
  }
}

// Extracts `number_of_bits` bits from `lwe_in`, most significant first,
// storing each as a keyswitched small-LWE ciphertext in `list_lwe_out`
// (LSB-first slot order).  Per iteration: keyswitch the shifted state, save
// the bit, then (except for the last bit) bootstrap with a sign LUT and
// remove the extracted bit from the state.
// NOTE(review): reconstructed from a mangled patch — the <Torus, params>
// launch arguments below were stripped by extraction; confirm against the
// original commit.
template <typename Torus, class params>
__host__ void host_extract_bits(
    void *v_stream, Torus *list_lwe_out, Torus *lwe_in, Torus *lwe_in_buffer,
    Torus *lwe_in_shifted_buffer, Torus *lwe_out_ks_buffer,
    Torus *lwe_out_pbs_buffer, Torus *lut_pbs, uint32_t *lut_vector_indexes,
    Torus *ksk, double2 *fourier_bsk, uint32_t number_of_bits,
    uint32_t delta_log, uint32_t lwe_dimension_before,
    uint32_t lwe_dimension_after, uint32_t base_log_bsk,
    uint32_t l_gadget_bsk, uint32_t base_log_ksk, uint32_t l_gadget_ksk,
    uint32_t number_of_samples) {
  auto stream = static_cast<cudaStream_t *>(v_stream);
  uint32_t ciphertext_n_bits = sizeof(Torus) * 8;

  int blocks = 1;
  int threads = params::degree / params::opt;

  // state = copy of input; shifted = input * 2^(n_bits - delta_log - 1) so
  // the most significant message bit lands on the padding bit.
  copy_and_shift_lwe<Torus, params><<<blocks, threads, 0, *stream>>>(
      lwe_in_buffer, lwe_in_shifted_buffer, lwe_in,
      (Torus)(1ll << (ciphertext_n_bits - delta_log - 1)));

  for (uint32_t bit_idx = 0; bit_idx < number_of_bits; bit_idx++) {
    cuda_keyswitch_lwe_ciphertext_vector(
        v_stream, lwe_out_ks_buffer, lwe_in_shifted_buffer, ksk,
        lwe_dimension_before, lwe_dimension_after, base_log_ksk, l_gadget_ksk,
        1);

    // Store the keyswitched bit; output slots are filled LSB-last.
    copy_small_lwe<<<1, 256, 0, *stream>>>(
        list_lwe_out, lwe_out_ks_buffer, lwe_dimension_after + 1,
        number_of_bits, number_of_bits - bit_idx - 1);

    // The last bit needs no further bootstrap / state update.
    if (bit_idx == number_of_bits - 1)
      break;

    // Add 2^(n_bits - 2) to the body before the sign-LUT bootstrap.
    add_to_body<<<1, 1, 0, *stream>>>(
        lwe_out_ks_buffer, lwe_dimension_after,
        (Torus)(1ll << (ciphertext_n_bits - 2)));

    // LUT body = -alpha with alpha = delta * 2^{bit_idx - 1}.
    fill_lut_body_for_current_bit<Torus, params>
        <<<blocks, threads, 0, *stream>>>(
            lut_pbs, (Torus)(0ll - (1ll << (delta_log - 1 + bit_idx))));

    host_bootstrap_low_latency<Torus, params>(
        v_stream, lwe_out_pbs_buffer, lut_pbs, lut_vector_indexes,
        lwe_out_ks_buffer, fourier_bsk, lwe_dimension_after,
        lwe_dimension_before, base_log_bsk, l_gadget_bsk, number_of_samples,
        1);

    add_sub_and_mul_lwe<Torus, params><<<1, threads, 0, *stream>>>(
        lwe_in_shifted_buffer, lwe_in_buffer, lwe_out_pbs_buffer,
        (Torus)(1ll << (delta_log - 1 + bit_idx)),
        (Torus)(1ll << (ciphertext_n_bits - delta_log - bit_idx - 2)));
  }
}

#endif // WO_PBS_H

// ---- src/crypto/gadget.cuh (hunk carried by the patch): adds a
// `#pragma once` include guard; no code change. ----

// ---- src/keyswitch.cuh ----

// One level of balanced gadget decomposition, consumed least-significant
// level first.  `state` holds the remaining (pre-shifted) levels and is
// updated in place; `mod_b_mask` must equal 2^base_log - 1.  The carry maps
// each digit into the balanced range around zero (sign-corrected digit, with
// the carry propagated into the next level).
template <typename Torus>
__device__ Torus decompose_one(Torus &state, Torus mod_b_mask, int base_log) {
  Torus res = state & mod_b_mask;
  state >>= base_log;
  // carry = 1 when the raw digit must be re-centered (top bit of the
  // adjusted digit); then digit -= carry * 2^base_log and the carry is
  // absorbed by the next level.
  Torus carry = ((res - 1ll) | state) & res;
  carry >>= base_log - 1;
  state += carry;
  res -= carry << base_log;
  return res;
}

// NOTE(review): the rest of the patch — the `keyswitch` kernel hunk that
// switches it to decompose_one and reversed ksk block order — is truncated
// in this capture; its opening is preserved below so the remnant that
// follows stays attached.
__global__ void keyswitch(Torus *lwe_out,
Torus *lwe_in, Torus a_i = round_to_closest_multiple(block_lwe_in[i], base_log, l_gadget); + Torus state = a_i >> (sizeof(Torus) * 8 - base_log * l_gadget); + Torus mod_b_mask = (1ll << base_log) - 1ll; + for (int j = 0; j < l_gadget; j++) { - auto ksk_block = get_ith_block(ksk, i, j, lwe_dimension_after, l_gadget); - Torus decomposed = gadget.decompose_one_level_single(a_i, (uint32_t)j); + auto ksk_block = get_ith_block(ksk, i, l_gadget - j - 1, + lwe_dimension_after, + l_gadget); + Torus decomposed = decompose_one(state, mod_b_mask, + base_log); for (int k = 0; k < lwe_part_per_thd; k++) { int idx = tid + k * blockDim.x; local_lwe_out[idx] -= (Torus)ksk_block[idx] * decomposed;