mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(cuda): add extract bits feature in concrete-cuda
- also, update decomposition algorithm for concrete-cuda keyswitch
This commit is contained in:
committed by
Agnès Leroy
parent
26f26a2132
commit
01ea1cf2f2
@@ -102,6 +102,55 @@ void cuda_cmux_tree_64(
|
||||
uint32_t l_gadget,
|
||||
uint32_t r,
|
||||
uint32_t max_shared_memory);
|
||||
|
||||
|
||||
|
||||
void cuda_extract_bits_32(
|
||||
void *v_stream,
|
||||
void *list_lwe_out,
|
||||
void *lwe_in,
|
||||
void *lwe_in_buffer,
|
||||
void *lwe_in_shifted_buffer,
|
||||
void *lwe_out_ks_buffer,
|
||||
void *lwe_out_pbs_buffer,
|
||||
void *lut_pbs,
|
||||
void *lut_vector_indexes,
|
||||
void *ksk,
|
||||
void *fourier_bsk,
|
||||
uint32_t number_of_bits,
|
||||
uint32_t delta_log,
|
||||
uint32_t lwe_dimension_before,
|
||||
uint32_t lwe_dimension_after,
|
||||
uint32_t base_log_bsk,
|
||||
uint32_t l_gadget_bsk,
|
||||
uint32_t base_log_ksk,
|
||||
uint32_t l_gadget_ksk,
|
||||
uint32_t number_of_samples);
|
||||
|
||||
|
||||
void cuda_extract_bits_64(
|
||||
void *v_stream,
|
||||
void *list_lwe_out,
|
||||
void *lwe_in,
|
||||
void *lwe_in_buffer,
|
||||
void *lwe_in_shifted_buffer,
|
||||
void *lwe_out_ks_buffer,
|
||||
void *lwe_out_pbs_buffer,
|
||||
void *lut_pbs,
|
||||
void *lut_vector_indexes,
|
||||
void *ksk,
|
||||
void *fourier_bsk,
|
||||
uint32_t number_of_bits,
|
||||
uint32_t delta_log,
|
||||
uint32_t lwe_dimension_before,
|
||||
uint32_t lwe_dimension_after,
|
||||
uint32_t base_log_bsk,
|
||||
uint32_t l_gadget_bsk,
|
||||
uint32_t base_log_ksk,
|
||||
uint32_t l_gadget_ksk,
|
||||
uint32_t number_of_samples);
|
||||
|
||||
|
||||
};
|
||||
|
||||
#ifdef __CUDACC__
|
||||
|
||||
@@ -100,4 +100,127 @@ void cuda_cmux_tree_64(
|
||||
max_shared_memory);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void cuda_extract_bits_32(
|
||||
void *v_stream,
|
||||
void *list_lwe_out,
|
||||
void *lwe_in,
|
||||
void *lwe_in_buffer,
|
||||
void *lwe_in_shifted_buffer,
|
||||
void *lwe_out_ks_buffer,
|
||||
void *lwe_out_pbs_buffer,
|
||||
void *lut_pbs,
|
||||
void *lut_vector_indexes,
|
||||
void *ksk,
|
||||
void *fourier_bsk,
|
||||
uint32_t number_of_bits,
|
||||
uint32_t delta_log,
|
||||
uint32_t lwe_dimension_before,
|
||||
uint32_t lwe_dimension_after,
|
||||
uint32_t base_log_bsk,
|
||||
uint32_t l_gadget_bsk,
|
||||
uint32_t base_log_ksk,
|
||||
uint32_t l_gadget_ksk,
|
||||
uint32_t number_of_samples)
|
||||
{
|
||||
switch (lwe_dimension_before) {
|
||||
case 512:
|
||||
host_extract_bits<uint32_t, Degree<512>>(
|
||||
v_stream, (uint32_t *)list_lwe_out, (uint32_t *)lwe_in,
|
||||
(uint32_t *)lwe_in_buffer, (uint32_t *)lwe_in_shifted_buffer,
|
||||
(uint32_t *)lwe_out_ks_buffer, (uint32_t *)lwe_out_pbs_buffer,
|
||||
(uint32_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint32_t *)ksk,
|
||||
(double2 *)fourier_bsk, number_of_bits, delta_log,
|
||||
lwe_dimension_before, lwe_dimension_after, base_log_bsk, l_gadget_bsk,
|
||||
base_log_ksk, l_gadget_ksk, number_of_samples);
|
||||
break;
|
||||
case 1024:
|
||||
host_extract_bits<uint32_t, Degree<1024>>(
|
||||
v_stream, (uint32_t *)list_lwe_out, (uint32_t *)lwe_in,
|
||||
(uint32_t *)lwe_in_buffer, (uint32_t *)lwe_in_shifted_buffer,
|
||||
(uint32_t *)lwe_out_ks_buffer, (uint32_t *)lwe_out_pbs_buffer,
|
||||
(uint32_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint32_t *)ksk,
|
||||
(double2 *)fourier_bsk, number_of_bits, delta_log,
|
||||
lwe_dimension_before, lwe_dimension_after, base_log_bsk, l_gadget_bsk,
|
||||
base_log_ksk, l_gadget_ksk, number_of_samples);
|
||||
break;
|
||||
case 2048:
|
||||
host_extract_bits<uint32_t, Degree<2048>>(
|
||||
v_stream, (uint32_t *)list_lwe_out, (uint32_t *)lwe_in,
|
||||
(uint32_t *)lwe_in_buffer, (uint32_t *)lwe_in_shifted_buffer,
|
||||
(uint32_t *)lwe_out_ks_buffer, (uint32_t *)lwe_out_pbs_buffer,
|
||||
(uint32_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint32_t *)ksk,
|
||||
(double2 *)fourier_bsk, number_of_bits, delta_log,
|
||||
lwe_dimension_before, lwe_dimension_after, base_log_bsk, l_gadget_bsk,
|
||||
base_log_ksk, l_gadget_ksk, number_of_samples);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
void cuda_extract_bits_64(
|
||||
void *v_stream,
|
||||
void *list_lwe_out,
|
||||
void *lwe_in,
|
||||
void *lwe_in_buffer,
|
||||
void *lwe_in_shifted_buffer,
|
||||
void *lwe_out_ks_buffer,
|
||||
void *lwe_out_pbs_buffer,
|
||||
void *lut_pbs,
|
||||
void *lut_vector_indexes,
|
||||
void *ksk,
|
||||
void *fourier_bsk,
|
||||
uint32_t number_of_bits,
|
||||
uint32_t delta_log,
|
||||
uint32_t lwe_dimension_before,
|
||||
uint32_t lwe_dimension_after,
|
||||
uint32_t base_log_bsk,
|
||||
uint32_t l_gadget_bsk,
|
||||
uint32_t base_log_ksk,
|
||||
uint32_t l_gadget_ksk,
|
||||
uint32_t number_of_samples)
|
||||
{
|
||||
switch (lwe_dimension_before) {
|
||||
case 512:
|
||||
host_extract_bits<uint64_t, Degree<512>>(
|
||||
v_stream, (uint64_t *)list_lwe_out, (uint64_t *)lwe_in,
|
||||
(uint64_t *)lwe_in_buffer, (uint64_t *)lwe_in_shifted_buffer,
|
||||
(uint64_t *)lwe_out_ks_buffer, (uint64_t *)lwe_out_pbs_buffer,
|
||||
(uint64_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint64_t *)ksk,
|
||||
(double2 *)fourier_bsk, number_of_bits, delta_log,
|
||||
lwe_dimension_before, lwe_dimension_after, base_log_bsk, l_gadget_bsk,
|
||||
base_log_ksk, l_gadget_ksk, number_of_samples);
|
||||
break;
|
||||
case 1024:
|
||||
host_extract_bits<uint64_t, Degree<1024>>(
|
||||
v_stream, (uint64_t *)list_lwe_out, (uint64_t *)lwe_in,
|
||||
(uint64_t *)lwe_in_buffer, (uint64_t *)lwe_in_shifted_buffer,
|
||||
(uint64_t *)lwe_out_ks_buffer, (uint64_t *)lwe_out_pbs_buffer,
|
||||
(uint64_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint64_t *)ksk,
|
||||
(double2 *)fourier_bsk, number_of_bits, delta_log,
|
||||
lwe_dimension_before, lwe_dimension_after, base_log_bsk, l_gadget_bsk,
|
||||
base_log_ksk, l_gadget_ksk, number_of_samples);
|
||||
break;
|
||||
case 2048:
|
||||
host_extract_bits<uint64_t, Degree<2048>>(
|
||||
v_stream, (uint64_t *)list_lwe_out, (uint64_t *)lwe_in,
|
||||
(uint64_t *)lwe_in_buffer, (uint64_t *)lwe_in_shifted_buffer,
|
||||
(uint64_t *)lwe_out_ks_buffer, (uint64_t *)lwe_out_pbs_buffer,
|
||||
(uint64_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint64_t *)ksk,
|
||||
(double2 *)fourier_bsk, number_of_bits, delta_log,
|
||||
lwe_dimension_before, lwe_dimension_after, base_log_bsk, l_gadget_bsk,
|
||||
base_log_ksk, l_gadget_ksk, number_of_samples);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -6,9 +6,7 @@
|
||||
#include "../include/helper_cuda.h"
|
||||
#include "bootstrap.h"
|
||||
#include "complex/operations.cuh"
|
||||
#include "crypto/gadget.cuh"
|
||||
#include "crypto/torus.cuh"
|
||||
#include "crypto/ggsw.cuh"
|
||||
#include "fft/bnsmfft.cuh"
|
||||
#include "fft/smfft.cuh"
|
||||
#include "fft/twiddles.cuh"
|
||||
@@ -18,6 +16,9 @@
|
||||
#include "polynomial/polynomial_math.cuh"
|
||||
#include "utils/memory.cuh"
|
||||
#include "utils/timer.cuh"
|
||||
#include "keyswitch.cuh"
|
||||
#include "bootstrap_low_latency.cuh"
|
||||
#include "crypto/ggsw.cuh"
|
||||
|
||||
template <typename T, class params>
|
||||
__device__ void fft(double2 *output, T *input){
|
||||
@@ -399,4 +400,206 @@ void host_cmux_tree(
|
||||
checkCudaErrors(cudaFree(d_mem));
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
// only works for big lwe for ks+bs case
|
||||
// state_lwe_buffer is copied from big lwe input
|
||||
// shifted_lwe_buffer is scalar multiplication of lwe input
|
||||
// blockIdx.x refers to input ciphertext id
|
||||
template <typename Torus, class params>
|
||||
__global__ void copy_and_shift_lwe(Torus *dst_copy, Torus *dst_shift,
|
||||
Torus *src, Torus value)
|
||||
{
|
||||
int blockId = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
auto cur_dst_copy = &dst_copy[blockId * (params::degree + 1)];
|
||||
auto cur_dst_shift = &dst_shift[blockId * (params::degree + 1)];
|
||||
auto cur_src = &src[blockId * (params::degree + 1)];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt; i++) {
|
||||
cur_dst_copy[tid] = cur_src[tid];
|
||||
cur_dst_shift[tid] = cur_src[tid] * value;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
|
||||
if (threadIdx.x == params::degree / params::opt - 1) {
|
||||
cur_dst_copy[params::degree] = cur_src[params::degree];
|
||||
cur_dst_shift[params::degree] = cur_src[params::degree] * value;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// only works for small lwe in ks+bs case
|
||||
// function copies lwe when length is not a power of two
|
||||
template <typename Torus>
|
||||
__global__ void copy_small_lwe(Torus *dst, Torus *src, uint32_t small_lwe_size, uint32_t number_of_bits,
|
||||
uint32_t lwe_id)
|
||||
{
|
||||
|
||||
|
||||
|
||||
size_t blockId = blockIdx.x;
|
||||
size_t threads_per_block = blockDim.x;
|
||||
size_t opt = small_lwe_size / threads_per_block;
|
||||
size_t rem = small_lwe_size & (threads_per_block - 1);
|
||||
|
||||
auto cur_lwe_list = &dst[blockId * small_lwe_size * number_of_bits];
|
||||
auto cur_dst = &cur_lwe_list[lwe_id * small_lwe_size];
|
||||
auto cur_src = &src[blockId * small_lwe_size];
|
||||
|
||||
size_t tid = threadIdx.x;
|
||||
for (int i = 0; i < opt; i++) {
|
||||
cur_dst[tid] = cur_src[tid];
|
||||
tid += threads_per_block;
|
||||
}
|
||||
|
||||
if (threadIdx.x < rem)
|
||||
cur_dst[tid] = cur_src[tid];
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
// only used in extract bits for one ciphertext
|
||||
// should be called with one block and one thread
|
||||
// NOTE: check if putting this functionality in copy_small_lwe or
|
||||
// fill_pbs_lut vector is faster
|
||||
template <typename Torus>
|
||||
__global__ void add_to_body(Torus *lwe, size_t lwe_dimension,
|
||||
Torus value) {
|
||||
lwe[blockIdx.x * (lwe_dimension + 1) + lwe_dimension] += value;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Fill lut(only body) for the current bit (equivalent to trivial encryption as
|
||||
// mask is 0s)
|
||||
// The LUT is filled with -alpha in each coefficient where alpha = delta*2^{bit_idx-1}
|
||||
template <typename Torus, class params>
|
||||
__global__ void fill_lut_body_for_current_bit(Torus *lut, Torus value)
|
||||
{
|
||||
Torus *cur_poly = &lut[params::degree];
|
||||
size_t tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt; i++) {
|
||||
cur_poly[tid] = value;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Add alpha where alpha = delta*2^{bit_idx-1} to end up with an encryption of 0 if the
|
||||
// extracted bit was 0 and 1 in the other case
|
||||
//
|
||||
// Remove the extracted bit from the state LWE to get a 0 at the extracted bit
|
||||
// location.
|
||||
//
|
||||
// Shift on padding bit for next iteration, that's why
|
||||
// alpha= 1ll << (ciphertext_n_bits - delta_log - bit_idx - 2) is used
|
||||
// instead of alpha= 1ll << (ciphertext_n_bits - delta_log - bit_idx - 1)
|
||||
template <typename Torus, class params>
|
||||
__global__ void add_sub_and_mul_lwe(Torus *shifted_lwe, Torus *state_lwe,
|
||||
Torus *pbs_lwe_out, Torus add_value,
|
||||
Torus mul_value)
|
||||
{
|
||||
size_t tid = threadIdx.x;
|
||||
size_t blockId = blockIdx.x;
|
||||
auto cur_shifted_lwe = &shifted_lwe[blockId * (params::degree + 1)];
|
||||
auto cur_state_lwe = &state_lwe[blockId * (params::degree + 1)];
|
||||
auto cur_pbs_lwe_out = &pbs_lwe_out[blockId * (params::degree + 1)];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt; i++) {
|
||||
cur_shifted_lwe[tid] = cur_state_lwe[tid] -= cur_pbs_lwe_out[tid];
|
||||
cur_shifted_lwe[tid] *= mul_value;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
|
||||
if (threadIdx.x == params::degree / params::opt - 1) {
|
||||
cur_shifted_lwe[params::degree] = cur_state_lwe[params::degree] -=
|
||||
(cur_pbs_lwe_out[params::degree] + add_value);
|
||||
cur_shifted_lwe[params::degree] *= mul_value;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename Torus, class params>
|
||||
__host__ void host_extract_bits(
|
||||
void *v_stream,
|
||||
Torus *list_lwe_out,
|
||||
Torus *lwe_in,
|
||||
Torus *lwe_in_buffer,
|
||||
Torus *lwe_in_shifted_buffer,
|
||||
Torus *lwe_out_ks_buffer,
|
||||
Torus *lwe_out_pbs_buffer,
|
||||
Torus *lut_pbs,
|
||||
uint32_t *lut_vector_indexes,
|
||||
Torus *ksk,
|
||||
double2 *fourier_bsk,
|
||||
uint32_t number_of_bits,
|
||||
uint32_t delta_log,
|
||||
uint32_t lwe_dimension_before,
|
||||
uint32_t lwe_dimension_after,
|
||||
uint32_t base_log_bsk,
|
||||
uint32_t l_gadget_bsk,
|
||||
uint32_t base_log_ksk,
|
||||
uint32_t l_gadget_ksk,
|
||||
uint32_t number_of_samples)
|
||||
{
|
||||
auto stream = static_cast<cudaStream_t *>(v_stream);
|
||||
uint32_t ciphertext_n_bits = sizeof(Torus) * 8;
|
||||
|
||||
int blocks = 1;
|
||||
int threads = params::degree / params::opt;
|
||||
|
||||
copy_and_shift_lwe<Torus, params><<<blocks, threads, 0, *stream>>>
|
||||
(lwe_in_buffer, lwe_in_shifted_buffer, lwe_in,
|
||||
1ll << (ciphertext_n_bits - delta_log - 1));
|
||||
|
||||
for (int bit_idx = 0; bit_idx < number_of_bits; bit_idx++) {
|
||||
cuda_keyswitch_lwe_ciphertext_vector(v_stream, lwe_out_ks_buffer,
|
||||
lwe_in_shifted_buffer, ksk,
|
||||
lwe_dimension_before,
|
||||
lwe_dimension_after, base_log_ksk,
|
||||
l_gadget_ksk, 1);
|
||||
|
||||
copy_small_lwe<<<1, 256, 0, *stream>>>(list_lwe_out,
|
||||
lwe_out_ks_buffer,
|
||||
lwe_dimension_after + 1,
|
||||
number_of_bits,
|
||||
number_of_bits - bit_idx - 1);
|
||||
|
||||
if (bit_idx == number_of_bits - 1) {
|
||||
break;
|
||||
}
|
||||
|
||||
add_to_body<Torus><<<1, 1, 0, *stream>>>(lwe_out_ks_buffer,
|
||||
lwe_dimension_after,
|
||||
1ll << (ciphertext_n_bits - 2));
|
||||
|
||||
|
||||
fill_lut_body_for_current_bit<Torus, params>
|
||||
<<<blocks, threads, 0,*stream>>> (lut_pbs, 0ll - 1ll << (
|
||||
delta_log - 1 +
|
||||
bit_idx));
|
||||
|
||||
host_bootstrap_low_latency<Torus, params>(v_stream, lwe_out_pbs_buffer,
|
||||
lut_pbs, lut_vector_indexes,
|
||||
lwe_out_ks_buffer, fourier_bsk,
|
||||
lwe_dimension_after, lwe_dimension_before,
|
||||
base_log_bsk, l_gadget_bsk, number_of_samples,
|
||||
1);
|
||||
|
||||
add_sub_and_mul_lwe<Torus, params><<<1, threads, 0, *stream>>>(
|
||||
lwe_in_shifted_buffer, lwe_in_buffer, lwe_out_pbs_buffer,
|
||||
1ll << (delta_log - 1 + bit_idx),
|
||||
1ll << (ciphertext_n_bits - delta_log - bit_idx - 2) );
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
#endif //WO_PBS_H
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "polynomial/polynomial.cuh"
|
||||
#include <cstdint>
|
||||
|
||||
#pragma once
|
||||
template <typename T, class params> class GadgetMatrix {
|
||||
private:
|
||||
uint32_t l_gadget;
|
||||
|
||||
@@ -17,6 +17,18 @@ __device__ Torus *get_ith_block(Torus *ksk, int i, int level,
|
||||
return ptr;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
__device__ Torus decompose_one(Torus &state, Torus mod_b_mask,
|
||||
int base_log) {
|
||||
Torus res = state & mod_b_mask;
|
||||
state >>= base_log;
|
||||
Torus carry = ((res - 1ll) | state) & res;
|
||||
carry >>= base_log - 1;
|
||||
state += carry;
|
||||
res -= carry << base_log;
|
||||
return res;
|
||||
}
|
||||
|
||||
/*
|
||||
* keyswitch kernel
|
||||
* Each thread handles a piece of the following equation:
|
||||
@@ -76,9 +88,15 @@ __global__ void keyswitch(Torus *lwe_out, Torus *lwe_in,
|
||||
Torus a_i = round_to_closest_multiple(block_lwe_in[i], base_log,
|
||||
l_gadget);
|
||||
|
||||
Torus state = a_i >> (sizeof(Torus) * 8 - base_log * l_gadget);
|
||||
Torus mod_b_mask = (1ll << base_log) - 1ll;
|
||||
|
||||
for (int j = 0; j < l_gadget; j++) {
|
||||
auto ksk_block = get_ith_block(ksk, i, j, lwe_dimension_after, l_gadget);
|
||||
Torus decomposed = gadget.decompose_one_level_single(a_i, (uint32_t)j);
|
||||
auto ksk_block = get_ith_block(ksk, i, l_gadget - j - 1,
|
||||
lwe_dimension_after,
|
||||
l_gadget);
|
||||
Torus decomposed = decompose_one<Torus>(state, mod_b_mask,
|
||||
base_log);
|
||||
for (int k = 0; k < lwe_part_per_thd; k++) {
|
||||
int idx = tid + k * blockDim.x;
|
||||
local_lwe_out[idx] -= (Torus)ksk_block[idx] * decomposed;
|
||||
|
||||
Reference in New Issue
Block a user