feat(cuda): add extract bits feature in concrete-cuda

- also, update decomposition algorithm for concrete-cuda keyswitch
This commit is contained in:
Beka Barbakadze
2022-09-25 23:03:58 +04:00
committed by Agnès Leroy
parent 26f26a2132
commit 01ea1cf2f2
5 changed files with 399 additions and 5 deletions

View File

@@ -102,6 +102,55 @@ void cuda_cmux_tree_64(
uint32_t l_gadget,
uint32_t r,
uint32_t max_shared_memory);
// Extracts `number_of_bits` bits from each 32-bit LWE ciphertext in `lwe_in`.
// The keyswitched single-bit LWEs are written to `list_lwe_out`, most
// significant bit first. The `*_buffer` arguments are caller-allocated device
// scratch; `ksk` is the keyswitch key, `fourier_bsk` the Fourier-domain
// bootstrap key, `lut_pbs`/`lut_vector_indexes` the PBS test vector and its
// index table. All work is enqueued on the CUDA stream wrapped by `v_stream`.
// Only lwe_dimension_before in {512, 1024, 2048} is supported; other values
// are a no-op.
void cuda_extract_bits_32(
void *v_stream,
void *list_lwe_out,
void *lwe_in,
void *lwe_in_buffer,
void *lwe_in_shifted_buffer,
void *lwe_out_ks_buffer,
void *lwe_out_pbs_buffer,
void *lut_pbs,
void *lut_vector_indexes,
void *ksk,
void *fourier_bsk,
uint32_t number_of_bits,
uint32_t delta_log,
uint32_t lwe_dimension_before,
uint32_t lwe_dimension_after,
uint32_t base_log_bsk,
uint32_t l_gadget_bsk,
uint32_t base_log_ksk,
uint32_t l_gadget_ksk,
uint32_t number_of_samples);
// 64-bit variant of the bit-extraction entry point: extracts
// `number_of_bits` bits from each 64-bit LWE ciphertext in `lwe_in` and
// writes the keyswitched single-bit LWEs to `list_lwe_out`, most significant
// bit first. Buffers, keys and LUT arguments play the same roles as in the
// 32-bit version; all work is enqueued on the stream wrapped by `v_stream`.
// Only lwe_dimension_before in {512, 1024, 2048} is supported; other values
// are a no-op.
void cuda_extract_bits_64(
void *v_stream,
void *list_lwe_out,
void *lwe_in,
void *lwe_in_buffer,
void *lwe_in_shifted_buffer,
void *lwe_out_ks_buffer,
void *lwe_out_pbs_buffer,
void *lut_pbs,
void *lut_vector_indexes,
void *ksk,
void *fourier_bsk,
uint32_t number_of_bits,
uint32_t delta_log,
uint32_t lwe_dimension_before,
uint32_t lwe_dimension_after,
uint32_t base_log_bsk,
uint32_t l_gadget_bsk,
uint32_t base_log_ksk,
uint32_t l_gadget_ksk,
uint32_t number_of_samples);
};
#ifdef __CUDACC__

View File

@@ -100,4 +100,127 @@ void cuda_cmux_tree_64(
max_shared_memory);
break;
}
}
}
// 32-bit bit-extraction entry point: dispatches to the templated host
// implementation specialized on the input polynomial size.
// Only lwe_dimension_before in {512, 1024, 2048} is handled; any other value
// launches nothing.
void cuda_extract_bits_32(
    void *v_stream,
    void *list_lwe_out,
    void *lwe_in,
    void *lwe_in_buffer,
    void *lwe_in_shifted_buffer,
    void *lwe_out_ks_buffer,
    void *lwe_out_pbs_buffer,
    void *lut_pbs,
    void *lut_vector_indexes,
    void *ksk,
    void *fourier_bsk,
    uint32_t number_of_bits,
    uint32_t delta_log,
    uint32_t lwe_dimension_before,
    uint32_t lwe_dimension_after,
    uint32_t base_log_bsk,
    uint32_t l_gadget_bsk,
    uint32_t base_log_ksk,
    uint32_t l_gadget_ksk,
    uint32_t number_of_samples)
{
  // Cast the opaque device pointers once so every case reads identically.
  uint32_t *list_out = (uint32_t *)list_lwe_out;
  uint32_t *in = (uint32_t *)lwe_in;
  uint32_t *in_buf = (uint32_t *)lwe_in_buffer;
  uint32_t *in_shifted = (uint32_t *)lwe_in_shifted_buffer;
  uint32_t *ks_buf = (uint32_t *)lwe_out_ks_buffer;
  uint32_t *pbs_buf = (uint32_t *)lwe_out_pbs_buffer;
  uint32_t *lut = (uint32_t *)lut_pbs;
  uint32_t *lut_idx = (uint32_t *)lut_vector_indexes;
  uint32_t *ksk_ptr = (uint32_t *)ksk;
  double2 *bsk = (double2 *)fourier_bsk;

  switch (lwe_dimension_before) {
  case 512:
    host_extract_bits<uint32_t, Degree<512>>(
        v_stream, list_out, in, in_buf, in_shifted, ks_buf, pbs_buf, lut,
        lut_idx, ksk_ptr, bsk, number_of_bits, delta_log,
        lwe_dimension_before, lwe_dimension_after, base_log_bsk, l_gadget_bsk,
        base_log_ksk, l_gadget_ksk, number_of_samples);
    break;
  case 1024:
    host_extract_bits<uint32_t, Degree<1024>>(
        v_stream, list_out, in, in_buf, in_shifted, ks_buf, pbs_buf, lut,
        lut_idx, ksk_ptr, bsk, number_of_bits, delta_log,
        lwe_dimension_before, lwe_dimension_after, base_log_bsk, l_gadget_bsk,
        base_log_ksk, l_gadget_ksk, number_of_samples);
    break;
  case 2048:
    host_extract_bits<uint32_t, Degree<2048>>(
        v_stream, list_out, in, in_buf, in_shifted, ks_buf, pbs_buf, lut,
        lut_idx, ksk_ptr, bsk, number_of_bits, delta_log,
        lwe_dimension_before, lwe_dimension_after, base_log_bsk, l_gadget_bsk,
        base_log_ksk, l_gadget_ksk, number_of_samples);
    break;
  default:
    // Unsupported polynomial size: nothing is launched.
    break;
  }
}
// 64-bit bit-extraction entry point: dispatches to the templated host
// implementation specialized on the input polynomial size.
// Note: lut_vector_indexes is always a uint32_t table, even for the 64-bit
// torus. Only lwe_dimension_before in {512, 1024, 2048} is handled; any
// other value launches nothing.
void cuda_extract_bits_64(
    void *v_stream,
    void *list_lwe_out,
    void *lwe_in,
    void *lwe_in_buffer,
    void *lwe_in_shifted_buffer,
    void *lwe_out_ks_buffer,
    void *lwe_out_pbs_buffer,
    void *lut_pbs,
    void *lut_vector_indexes,
    void *ksk,
    void *fourier_bsk,
    uint32_t number_of_bits,
    uint32_t delta_log,
    uint32_t lwe_dimension_before,
    uint32_t lwe_dimension_after,
    uint32_t base_log_bsk,
    uint32_t l_gadget_bsk,
    uint32_t base_log_ksk,
    uint32_t l_gadget_ksk,
    uint32_t number_of_samples)
{
  // Cast the opaque device pointers once so every case reads identically.
  uint64_t *list_out = (uint64_t *)list_lwe_out;
  uint64_t *in = (uint64_t *)lwe_in;
  uint64_t *in_buf = (uint64_t *)lwe_in_buffer;
  uint64_t *in_shifted = (uint64_t *)lwe_in_shifted_buffer;
  uint64_t *ks_buf = (uint64_t *)lwe_out_ks_buffer;
  uint64_t *pbs_buf = (uint64_t *)lwe_out_pbs_buffer;
  uint64_t *lut = (uint64_t *)lut_pbs;
  uint32_t *lut_idx = (uint32_t *)lut_vector_indexes;
  uint64_t *ksk_ptr = (uint64_t *)ksk;
  double2 *bsk = (double2 *)fourier_bsk;

  switch (lwe_dimension_before) {
  case 512:
    host_extract_bits<uint64_t, Degree<512>>(
        v_stream, list_out, in, in_buf, in_shifted, ks_buf, pbs_buf, lut,
        lut_idx, ksk_ptr, bsk, number_of_bits, delta_log,
        lwe_dimension_before, lwe_dimension_after, base_log_bsk, l_gadget_bsk,
        base_log_ksk, l_gadget_ksk, number_of_samples);
    break;
  case 1024:
    host_extract_bits<uint64_t, Degree<1024>>(
        v_stream, list_out, in, in_buf, in_shifted, ks_buf, pbs_buf, lut,
        lut_idx, ksk_ptr, bsk, number_of_bits, delta_log,
        lwe_dimension_before, lwe_dimension_after, base_log_bsk, l_gadget_bsk,
        base_log_ksk, l_gadget_ksk, number_of_samples);
    break;
  case 2048:
    host_extract_bits<uint64_t, Degree<2048>>(
        v_stream, list_out, in, in_buf, in_shifted, ks_buf, pbs_buf, lut,
        lut_idx, ksk_ptr, bsk, number_of_bits, delta_log,
        lwe_dimension_before, lwe_dimension_after, base_log_bsk, l_gadget_bsk,
        base_log_ksk, l_gadget_ksk, number_of_samples);
    break;
  default:
    // Unsupported polynomial size: nothing is launched.
    break;
  }
}

View File

@@ -6,9 +6,7 @@
#include "../include/helper_cuda.h"
#include "bootstrap.h"
#include "complex/operations.cuh"
#include "crypto/gadget.cuh"
#include "crypto/torus.cuh"
#include "crypto/ggsw.cuh"
#include "fft/bnsmfft.cuh"
#include "fft/smfft.cuh"
#include "fft/twiddles.cuh"
@@ -18,6 +16,9 @@
#include "polynomial/polynomial_math.cuh"
#include "utils/memory.cuh"
#include "utils/timer.cuh"
#include "keyswitch.cuh"
#include "bootstrap_low_latency.cuh"
#include "crypto/ggsw.cuh"
template <typename T, class params>
__device__ void fft(double2 *output, T *input){
@@ -399,4 +400,206 @@ void host_cmux_tree(
checkCudaErrors(cudaFree(d_mem));
}
// Only works for the big LWE in the KS+BS case.
// For each input ciphertext (selected by blockIdx.x), writes two outputs:
//   dst_copy  — a verbatim copy of the input LWE (the state buffer),
//   dst_shift — the input LWE with every coefficient multiplied by `value`.
// Each LWE holds params::degree mask coefficients plus one body element;
// every thread handles params::opt strided coefficients, and one designated
// thread additionally handles the body element.
template <typename Torus, class params>
__global__ void copy_and_shift_lwe(Torus *dst_copy, Torus *dst_shift,
                                   Torus *src, Torus value)
{
  const int stride = params::degree / params::opt;
  const size_t offset = blockIdx.x * (params::degree + 1);
  Torus *block_copy = &dst_copy[offset];
  Torus *block_shift = &dst_shift[offset];
  Torus *block_src = &src[offset];

  int idx = threadIdx.x;
#pragma unroll
  for (int i = 0; i < params::opt; i++) {
    Torus coef = block_src[idx];
    block_copy[idx] = coef;
    block_shift[idx] = coef * value;
    idx += stride;
  }
  // The last thread of the strided sweep also copies/shifts the body.
  if (threadIdx.x == stride - 1) {
    Torus body = block_src[params::degree];
    block_copy[params::degree] = body;
    block_shift[params::degree] = body * value;
  }
}
// Only works for the small LWE in the KS+BS case.
// Copies one keyswitched LWE of `small_lwe_size` elements (not necessarily a
// power of two) into slot `lwe_id` of the per-sample output list `dst`.
// blockIdx.x selects the input ciphertext; the copy is strided over the
// threads of the block, and the first `rem` threads handle the tail.
template <typename Torus>
__global__ void copy_small_lwe(Torus *dst, Torus *src, uint32_t small_lwe_size, uint32_t number_of_bits,
                               uint32_t lwe_id)
{
  size_t blockId = blockIdx.x;
  size_t threads_per_block = blockDim.x;
  size_t opt = small_lwe_size / threads_per_block;
  // Use a modulo instead of `& (threads_per_block - 1)`: the bit-mask form
  // is only correct when the block size is a power of two, which the kernel
  // signature does not guarantee.
  size_t rem = small_lwe_size % threads_per_block;

  // dst is laid out as [sample][bit][coefficient].
  auto cur_lwe_list = &dst[blockId * small_lwe_size * number_of_bits];
  auto cur_dst = &cur_lwe_list[lwe_id * small_lwe_size];
  auto cur_src = &src[blockId * small_lwe_size];

  size_t tid = threadIdx.x;
  for (size_t i = 0; i < opt; i++) {
    cur_dst[tid] = cur_src[tid];
    tid += threads_per_block;
  }
  // Tail: tid already points past the strided region for these threads.
  if (threadIdx.x < rem)
    cur_dst[tid] = cur_src[tid];
}
// Adds `value` to the body coefficient (the last of the lwe_dimension + 1
// elements) of the LWE selected by blockIdx.x.
// Only used in extract bits for one ciphertext; should be called with one
// block and one thread.
// NOTE: check if putting this functionality in copy_small_lwe or
// fill_pbs_lut vector is faster
template <typename Torus>
__global__ void add_to_body(Torus *lwe, size_t lwe_dimension,
                            Torus value) {
  lwe[blockIdx.x * (lwe_dimension + 1) + lwe_dimension] += value;
}
// Fill only the body polynomial of the PBS test vector (LUT) for the current
// bit — equivalent to a trivial encryption since the mask is all zeros.
// Every coefficient of the body receives the same constant `value`
// (-alpha at the call site, with alpha = delta * 2^{bit_idx - 1}).
template <typename Torus, class params>
__global__ void fill_lut_body_for_current_bit(Torus *lut, Torus value)
{
  // The body polynomial starts right after the params::degree mask
  // coefficients.
  Torus *body = &lut[params::degree];
  const int stride = params::degree / params::opt;
  int idx = threadIdx.x;
#pragma unroll
  for (int i = 0; i < params::opt; i++) {
    body[idx] = value;
    idx += stride;
  }
}
// Per sample (blockIdx.x), updates the state LWE and produces the next
// shifted LWE in one pass:
//   1. state -= pbs_out (and additionally -= add_value on the body). Adding
//      alpha = delta * 2^{bit_idx - 1} makes the PBS output an encryption of
//      0 if the extracted bit was 0 and 1 otherwise, so subtracting removes
//      the extracted bit from the state, leaving a 0 at that bit position.
//   2. shifted = state * mul_value, which moves the next bit onto the
//      padding bit for the following iteration; that is why the caller
//      passes mul_value = 1 << (ciphertext_n_bits - delta_log - bit_idx - 2)
//      instead of ... - 1.
template <typename Torus, class params>
__global__ void add_sub_and_mul_lwe(Torus *shifted_lwe, Torus *state_lwe,
                                    Torus *pbs_lwe_out, Torus add_value,
                                    Torus mul_value)
{
  const size_t lwe_size = params::degree + 1;
  Torus *cur_shifted = &shifted_lwe[blockIdx.x * lwe_size];
  Torus *cur_state = &state_lwe[blockIdx.x * lwe_size];
  Torus *cur_pbs = &pbs_lwe_out[blockIdx.x * lwe_size];

  size_t idx = threadIdx.x;
#pragma unroll
  for (int i = 0; i < params::opt; i++) {
    cur_state[idx] -= cur_pbs[idx];
    cur_shifted[idx] = cur_state[idx] * mul_value;
    idx += params::degree / params::opt;
  }
  // One designated thread also updates the body, where add_value is
  // subtracted as well.
  if (threadIdx.x == params::degree / params::opt - 1) {
    cur_state[params::degree] -= cur_pbs[params::degree] + add_value;
    cur_shifted[params::degree] = cur_state[params::degree] * mul_value;
  }
}
// Extracts `number_of_bits` bits from each input LWE ciphertext.
//
// For every bit (MSB first): the shifted state is keyswitched to the small
// key and stored into `list_lwe_out` at slot number_of_bits - bit_idx - 1.
// For every bit except the last, a PBS with a constant LUT then maps the
// keyswitched ciphertext to +/- alpha (alpha = delta * 2^{bit_idx - 1}); the
// PBS output is subtracted from the state to clear the extracted bit, and
// the state is rescaled so the next bit lands on the padding bit.
//
// All kernels are enqueued on the stream wrapped by `v_stream`; the function
// does not synchronize. The `*_buffer` arguments are caller-allocated device
// scratch.
// NOTE(review): the keyswitch is launched with a sample count of 1 while the
// PBS uses number_of_samples — presumably this path only supports a single
// sample; confirm with callers.
template <typename Torus, class params>
__host__ void host_extract_bits(
    void *v_stream,
    Torus *list_lwe_out,
    Torus *lwe_in,
    Torus *lwe_in_buffer,
    Torus *lwe_in_shifted_buffer,
    Torus *lwe_out_ks_buffer,
    Torus *lwe_out_pbs_buffer,
    Torus *lut_pbs,
    uint32_t *lut_vector_indexes,
    Torus *ksk,
    double2 *fourier_bsk,
    uint32_t number_of_bits,
    uint32_t delta_log,
    uint32_t lwe_dimension_before,
    uint32_t lwe_dimension_after,
    uint32_t base_log_bsk,
    uint32_t l_gadget_bsk,
    uint32_t base_log_ksk,
    uint32_t l_gadget_ksk,
    uint32_t number_of_samples)
{
  auto stream = static_cast<cudaStream_t *>(v_stream);
  uint32_t ciphertext_n_bits = sizeof(Torus) * 8;

  int blocks = 1;
  int threads = params::degree / params::opt;

  // state <- input; shifted <- input * 2^(n_bits - delta_log - 1), so the
  // most significant message bit sits on the padding bit.
  copy_and_shift_lwe<Torus, params><<<blocks, threads, 0, *stream>>>(
      lwe_in_buffer, lwe_in_shifted_buffer, lwe_in,
      1ll << (ciphertext_n_bits - delta_log - 1));

  // Unsigned counter: number_of_bits is uint32_t, avoid a signed/unsigned
  // comparison.
  for (uint32_t bit_idx = 0; bit_idx < number_of_bits; bit_idx++) {
    // Keyswitch the shifted state down to the small LWE key.
    cuda_keyswitch_lwe_ciphertext_vector(v_stream, lwe_out_ks_buffer,
                                         lwe_in_shifted_buffer, ksk,
                                         lwe_dimension_before,
                                         lwe_dimension_after, base_log_ksk,
                                         l_gadget_ksk, 1);

    // Store the keyswitched bit MSB-first in the output list.
    copy_small_lwe<<<1, 256, 0, *stream>>>(list_lwe_out, lwe_out_ks_buffer,
                                           lwe_dimension_after + 1,
                                           number_of_bits,
                                           number_of_bits - bit_idx - 1);

    // The last extracted bit needs no state cleanup.
    if (bit_idx == number_of_bits - 1) {
      break;
    }

    // Shift the body by a quarter torus (2^(n_bits - 2)) before the PBS.
    add_to_body<Torus><<<1, 1, 0, *stream>>>(lwe_out_ks_buffer,
                                             lwe_dimension_after,
                                             1ll << (ciphertext_n_bits - 2));

    // LUT constant is -alpha with alpha = delta * 2^{bit_idx - 1}.
    // The shift must be parenthesized: the previous `0ll - 1ll << k` parsed
    // as `(0ll - 1ll) << k`, a left shift of a negative value (undefined
    // behavior pre-C++20), even though it yields the same bit pattern on
    // two's-complement targets.
    fill_lut_body_for_current_bit<Torus, params>
        <<<blocks, threads, 0, *stream>>>(
            lut_pbs, 0ll - (1ll << (delta_log - 1 + bit_idx)));

    host_bootstrap_low_latency<Torus, params>(v_stream, lwe_out_pbs_buffer,
                                              lut_pbs, lut_vector_indexes,
                                              lwe_out_ks_buffer, fourier_bsk,
                                              lwe_dimension_after,
                                              lwe_dimension_before,
                                              base_log_bsk, l_gadget_bsk,
                                              number_of_samples, 1);

    // state -= PBS output (+ alpha on the body); shifted = state scaled so
    // the next bit lands on the padding bit.
    add_sub_and_mul_lwe<Torus, params><<<1, threads, 0, *stream>>>(
        lwe_in_shifted_buffer, lwe_in_buffer, lwe_out_pbs_buffer,
        1ll << (delta_log - 1 + bit_idx),
        1ll << (ciphertext_n_bits - delta_log - bit_idx - 2));
  }
}
#endif //WO_PBS_H

View File

@@ -4,6 +4,7 @@
#include "polynomial/polynomial.cuh"
#include <cstdint>
#pragma once
template <typename T, class params> class GadgetMatrix {
private:
uint32_t l_gadget;

View File

@@ -17,6 +17,18 @@ __device__ Torus *get_ith_block(Torus *ksk, int i, int level,
return ptr;
}
// Extracts one digit of the signed (balanced) gadget decomposition from
// `state` and advances `state` to the next level.
// `mod_b_mask` must equal 2^base_log - 1. The raw digit is the low base_log
// bits of `state`; the branchless carry computation then recenters it so
// that digits close to the base B = 2^base_log wrap to a negative value and
// propagate a carry into the next (higher) digit.
// NOTE(review): the exact carry condition `((res - 1) | state) & res`,
// shifted down to bit base_log - 1, should match the CPU reference
// decomposition — confirm against it.
template <typename Torus>
__device__ Torus decompose_one(Torus &state, Torus mod_b_mask,
                               int base_log) {
  Torus res = state & mod_b_mask;  // raw digit in [0, B)
  state >>= base_log;              // consume base_log bits of the state
  // Branchless carry decision, isolated at bit position base_log - 1.
  Torus carry = ((res - 1ll) | state) & res;
  carry >>= base_log - 1;
  state += carry;                  // propagate the carry to the next digit
  res -= carry << base_log;        // recenter: subtract B when carry is set
  return res;
}
/*
* keyswitch kernel
* Each thread handles a piece of the following equation:
@@ -76,9 +88,15 @@ __global__ void keyswitch(Torus *lwe_out, Torus *lwe_in,
Torus a_i = round_to_closest_multiple(block_lwe_in[i], base_log,
l_gadget);
Torus state = a_i >> (sizeof(Torus) * 8 - base_log * l_gadget);
Torus mod_b_mask = (1ll << base_log) - 1ll;
for (int j = 0; j < l_gadget; j++) {
auto ksk_block = get_ith_block(ksk, i, j, lwe_dimension_after, l_gadget);
Torus decomposed = gadget.decompose_one_level_single(a_i, (uint32_t)j);
auto ksk_block = get_ith_block(ksk, i, l_gadget - j - 1,
lwe_dimension_after,
l_gadget);
Torus decomposed = decompose_one<Torus>(state, mod_b_mask,
base_log);
for (int k = 0; k < lwe_part_per_thd; k++) {
int idx = tid + k * blockDim.x;
local_lwe_out[idx] -= (Torus)ksk_block[idx] * decomposed;