feat(gpu): support keyswitch 64/32

Andrei Stoian
2025-10-10 17:25:03 +02:00
committed by Andrei Stoian
parent 14d49f0891
commit 78d1ce18c1
61 changed files with 2384 additions and 1686 deletions

View File

@@ -5,21 +5,14 @@
extern "C" {
void cuda_keyswitch_lwe_ciphertext_vector_32(
void cuda_keyswitch_lwe_ciphertext_vector_64_64(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void const *lwe_output_indexes, void const *lwe_array_in,
void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
uint32_t num_samples);
void cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void const *lwe_output_indexes, void const *lwe_array_in,
void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
uint32_t num_samples, const void *ks_tmp_buffer, bool uses_trivial_indexes);
void cuda_keyswitch_lwe_ciphertext_vector_64(
void cuda_keyswitch_lwe_ciphertext_vector_64_32(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void const *lwe_output_indexes, void const *lwe_array_in,
void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
@@ -31,6 +24,20 @@ uint64_t scratch_packing_keyswitch_lwe_list_to_glwe_64(
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t num_lwes, bool allocate_gpu_memory);
void cuda_keyswitch_gemm_lwe_ciphertext_vector_64_64(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void const *lwe_output_indexes, void const *lwe_array_in,
void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
uint32_t num_samples, const void *ks_tmp_buffer, bool uses_trivial_indexes);
void cuda_keyswitch_gemm_lwe_ciphertext_vector_64_32(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void const *lwe_output_indexes, void const *lwe_array_in,
void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
uint32_t num_samples, const void *ks_tmp_buffer, bool uses_trivial_indexes);
uint64_t scratch_cuda_keyswitch_gemm_64(void *stream, uint32_t gpu_index,
void **ks_tmp_memory,
uint32_t lwe_dimension_in,
@@ -65,6 +72,10 @@ void cleanup_packing_keyswitch_lwe_list_to_glwe(void *stream,
uint32_t gpu_index,
int8_t **fp_ks_buffer,
bool gpu_memory_allocated);
void cuda_closest_representable_64(void *stream, uint32_t gpu_index,
void const *input, void *output,
uint32_t base_log, uint32_t level_count);
}
#endif // CNCRT_KS_H_
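
For orientation, here is a minimal host-side sketch of calling the new 64->32 entry point declared above. Only the extern declaration is taken from this header; the wrapper, buffer names, and the base_log/level_count values are illustrative assumptions, not part of the commit.

#include <cstdint>

extern "C" void cuda_keyswitch_lwe_ciphertext_vector_64_32(
    void *stream, uint32_t gpu_index, void *lwe_array_out,
    void const *lwe_output_indexes, void const *lwe_array_in,
    void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
    uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
    uint32_t num_samples);

// Sketch: switch num_samples 64-bit LWEs to 32-bit LWEs. All pointers are
// device buffers allocated elsewhere; d_ksk_32 is a 32-bit key-switching key.
void ks_64_to_32(void *stream, uint32_t gpu_index, uint32_t *d_lwe_out_32,
                 const uint64_t *d_out_idx, const uint64_t *d_lwe_in_64,
                 const uint64_t *d_in_idx, const uint32_t *d_ksk_32,
                 uint32_t lwe_dim_in, uint32_t lwe_dim_out,
                 uint32_t num_samples) {
  cuda_keyswitch_lwe_ciphertext_vector_64_32(
      stream, gpu_index, d_lwe_out_32, d_out_idx, d_lwe_in_64, d_in_idx,
      d_ksk_32, lwe_dim_in, lwe_dim_out, /*base_log=*/4, /*level_count=*/8,
      num_samples);
}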

View File

@@ -105,11 +105,11 @@ aes_xor(CudaStreams streams, int_aes_encrypt_buffer<Torus> *mem,
* result.
*
*/
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ __forceinline__ void
aes_flush_inplace(CudaStreams streams, CudaRadixCiphertextFFI *data,
int_aes_encrypt_buffer<Torus> *mem, void *const *bsks,
Torus *const *ksks) {
KSTorus *const *ksks) {
integer_radix_apply_univariate_lookup_table<Torus>(streams, data, data, bsks,
ksks, mem->luts->flush_lut,
@@ -121,10 +121,12 @@ aes_flush_inplace(CudaStreams streams, CudaRadixCiphertextFFI *data,
* ciphertext, then flushes the result to ensure it's a valid bit.
*
*/
template <typename Torus>
__host__ __forceinline__ void aes_scalar_add_one_flush_inplace(
CudaStreams streams, CudaRadixCiphertextFFI *data,
int_aes_encrypt_buffer<Torus> *mem, void *const *bsks, Torus *const *ksks) {
template <typename Torus, typename KSTorus>
__host__ __forceinline__ void
aes_scalar_add_one_flush_inplace(CudaStreams streams,
CudaRadixCiphertextFFI *data,
int_aes_encrypt_buffer<Torus> *mem,
void *const *bsks, KSTorus *const *ksks) {
host_add_scalar_one_inplace<Torus>(streams, data, mem->params.message_modulus,
mem->params.carry_modulus);
@@ -142,11 +144,11 @@ __host__ __forceinline__ void aes_scalar_add_one_flush_inplace(
* ciphertext locations.
*
*/
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void
batch_vec_flush_inplace(CudaStreams streams, CudaRadixCiphertextFFI **targets,
size_t count, int_aes_encrypt_buffer<Torus> *mem,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
uint32_t num_radix_blocks = targets[0]->num_radix_blocks;
@@ -185,13 +187,13 @@ batch_vec_flush_inplace(CudaStreams streams, CudaRadixCiphertextFFI **targets,
* Batches multiple "and" operations into a single, large launch.
*
*/
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void batch_vec_and_inplace(CudaStreams streams,
CudaRadixCiphertextFFI **outs,
CudaRadixCiphertextFFI **lhs,
CudaRadixCiphertextFFI **rhs, size_t count,
int_aes_encrypt_buffer<Torus> *mem,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
uint32_t num_aes_inputs = outs[0]->num_radix_blocks;
@@ -274,13 +276,13 @@ __host__ void batch_vec_and_inplace(CudaStreams streams,
* [ptr] -> [R2b0, R2b1, R2b2, R2b3, R2b4, R2b5, R2b6, R2b7]
* ...
*/
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void vectorized_sbox_n_bytes(CudaStreams streams,
CudaRadixCiphertextFFI **sbox_io_bytes,
uint32_t num_bytes_parallel,
uint32_t num_aes_inputs,
int_aes_encrypt_buffer<Torus> *mem,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
uint32_t num_sbox_blocks = num_bytes_parallel * num_aes_inputs;
@@ -702,12 +704,12 @@ __host__ void vectorized_mul_by_2(CudaStreams streams,
* [ s'_3 ] [ 03 01 01 02 ] [ s_3 ]
*
*/
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void vectorized_mix_columns(CudaStreams streams,
CudaRadixCiphertextFFI *s_bits,
uint32_t num_aes_inputs,
int_aes_encrypt_buffer<Torus> *mem,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
constexpr uint32_t BITS_PER_BYTE = 8;
constexpr uint32_t BYTES_PER_COLUMN = 4;
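
The hunk above only shows the last row of the MixColumns matrix, so for reference here is the full cleartext column transform over GF(2^8) (standard AES; this sketch is ours, and vectorized_mix_columns evaluates the same linear map on encrypted bits):

#include <cstdint>

// xtime multiplies by {02} in GF(2^8) modulo the AES polynomial
// x^8 + x^4 + x^3 + x + 1; multiplying by {03} is xtime(x) ^ x.
static uint8_t xtime(uint8_t b) {
  return (uint8_t)((b << 1) ^ ((b >> 7) * 0x1B));
}

// Cleartext MixColumns on one 4-byte column (matrix rows on the right).
static void mix_column(uint8_t s[4]) {
  uint8_t a0 = s[0], a1 = s[1], a2 = s[2], a3 = s[3];
  s[0] = xtime(a0) ^ (xtime(a1) ^ a1) ^ a2 ^ a3; // 02 03 01 01
  s[1] = a0 ^ xtime(a1) ^ (xtime(a2) ^ a2) ^ a3; // 01 02 03 01
  s[2] = a0 ^ a1 ^ xtime(a2) ^ (xtime(a3) ^ a3); // 01 01 02 03
  s[3] = (xtime(a0) ^ a0) ^ a1 ^ a2 ^ xtime(a3); // 03 01 01 02
}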
@@ -842,11 +844,12 @@ __host__ void vectorized_mix_columns(CudaStreams streams,
* - AddRoundKey
*
*/
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void vectorized_aes_encrypt_inplace(
CudaStreams streams, CudaRadixCiphertextFFI *all_states_bitsliced,
CudaRadixCiphertextFFI const *round_keys, uint32_t num_aes_inputs,
int_aes_encrypt_buffer<Torus> *mem, void *const *bsks, Torus *const *ksks) {
int_aes_encrypt_buffer<Torus> *mem, void *const *bsks,
KSTorus *const *ksks) {
constexpr uint32_t BITS_PER_BYTE = 8;
constexpr uint32_t STATE_BYTES = 16;
@@ -987,11 +990,12 @@ __host__ void vectorized_aes_encrypt_inplace(
* The "transposed_states" buffer is updated in-place with the sum bits $S_i$.
*
*/
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void vectorized_aes_full_adder_inplace(
CudaStreams streams, CudaRadixCiphertextFFI *transposed_states,
const Torus *counter_bits_le_all_blocks, uint32_t num_aes_inputs,
int_aes_encrypt_buffer<Torus> *mem, void *const *bsks, Torus *const *ksks) {
int_aes_encrypt_buffer<Torus> *mem, void *const *bsks,
KSTorus *const *ksks) {
constexpr uint32_t NUM_BITS = 128;
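
As a plaintext point of reference for the homomorphic full adder above (entirely our sketch, assuming the little-endian bit order suggested by counter_bits_le_all_blocks):

#include <cstdint>

// Cleartext analogue: ripple-carry add the counter bits into 128 state bits,
// LSB first. The homomorphic version evaluates the same sum/carry logic with
// LUTs on encrypted bits.
static void full_adder_bits_inplace(uint8_t s[128], const uint8_t ctr[128]) {
  uint8_t carry = 0;
  for (int i = 0; i < 128; ++i) {
    uint8_t a = s[i] & 1, b = ctr[i] & 1;
    s[i] = a ^ b ^ carry;                // sum bit S_i
    carry = (a & b) | (carry & (a ^ b)); // carry out of position i
  }
}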
@@ -1091,12 +1095,13 @@ __host__ void vectorized_aes_full_adder_inplace(
* +---------------------------------+
*
*/
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_integer_aes_ctr_encrypt(
CudaStreams streams, CudaRadixCiphertextFFI *output,
CudaRadixCiphertextFFI const *iv, CudaRadixCiphertextFFI const *round_keys,
const Torus *counter_bits_le_all_blocks, uint32_t num_aes_inputs,
int_aes_encrypt_buffer<Torus> *mem, void *const *bsks, Torus *const *ksks) {
int_aes_encrypt_buffer<Torus> *mem, void *const *bsks,
KSTorus *const *ksks) {
constexpr uint32_t NUM_BITS = 128;
@@ -1148,13 +1153,13 @@ uint64_t scratch_cuda_integer_key_expansion(
* - If (i % 4 == 0): w_i = w_{i-4} + SubWord(RotWord(w_{i-1})) + Rcon[i/4]
* - If (i % 4 != 0): w_i = w_{i-4} + w_{i-1}
*/
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_integer_key_expansion(CudaStreams streams,
CudaRadixCiphertextFFI *expanded_keys,
CudaRadixCiphertextFFI const *key,
int_key_expansion_buffer<Torus> *mem,
void *const *bsks,
Torus *const *ksks) {
KSTorus *const *ksks) {
constexpr uint32_t BITS_PER_WORD = 32;
constexpr uint32_t BITS_PER_BYTE = 8;

View File

@@ -2,26 +2,9 @@
#include "keyswitch/keyswitch.h"
#include "packing_keyswitch.cuh"
/* Perform keyswitch on a batch of 32 bits input LWE ciphertexts.
* Head out to the equivalent operation on 64 bits for more details.
*/
void cuda_keyswitch_lwe_ciphertext_vector_32(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void *lwe_output_indexes, void *lwe_array_in, void *lwe_input_indexes,
void *ksk, uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
void *ksk_tmp_buffer, bool uses_trivial_indices) {
host_gemm_keyswitch_lwe_ciphertext_vector<uint32_t>(
static_cast<cudaStream_t>(stream), gpu_index,
static_cast<uint32_t *>(lwe_array_out),
static_cast<uint32_t *>(lwe_output_indexes),
static_cast<uint32_t *>(lwe_array_in),
static_cast<uint32_t *>(lwe_input_indexes), static_cast<uint32_t *>(ksk),
lwe_dimension_in, lwe_dimension_out, base_log, level_count, num_samples,
static_cast<uint32_t *>(ksk_tmp_buffer), uses_trivial_indices);
}
/* Perform keyswitch on a batch of 64 bits input LWE ciphertexts.
/* Perform keyswitch on a batch of 64-bit input LWE ciphertexts
 * using a 64-bit key-switching key. Uses the GEMM approach, which
 * achieves good throughput on large batches (128 LWEs on H100).
*
* - `v_stream` is a void pointer to the Cuda stream to be used in the kernel
* launch
@@ -37,7 +20,7 @@ void cuda_keyswitch_lwe_ciphertext_vector_32(
* This function calls a wrapper to a device kernel that performs the keyswitch
* - num_samples blocks of threads are launched
*/
void cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
void cuda_keyswitch_gemm_lwe_ciphertext_vector_64_64(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void const *lwe_output_indexes, void const *lwe_array_in,
void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
@@ -45,7 +28,7 @@ void cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
uint32_t num_samples, const void *ks_tmp_buffer,
bool uses_trivial_indices) {
host_gemm_keyswitch_lwe_ciphertext_vector<uint64_t>(
host_gemm_keyswitch_lwe_ciphertext_vector<uint64_t, uint64_t>(
static_cast<cudaStream_t>(stream), gpu_index,
static_cast<uint64_t *>(lwe_array_out),
static_cast<const uint64_t *>(lwe_output_indexes),
@@ -57,13 +40,37 @@ void cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
uses_trivial_indices);
}
void cuda_keyswitch_lwe_ciphertext_vector_64(
/* Perform keyswitch on a batch of 64-bit input LWE ciphertexts
 * using a 32-bit key-switching key, producing 32-bit LWE outputs.
 * Uses the GEMM approach, which achieves good throughput on large batches.
*/
void cuda_keyswitch_gemm_lwe_ciphertext_vector_64_32(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void const *lwe_output_indexes, void const *lwe_array_in,
void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
uint32_t num_samples, const void *ks_tmp_buffer,
bool uses_trivial_indices) {
host_gemm_keyswitch_lwe_ciphertext_vector<uint64_t, uint32_t>(
static_cast<cudaStream_t>(stream), gpu_index,
static_cast<uint32_t *>(lwe_array_out),
static_cast<const uint64_t *>(lwe_output_indexes),
static_cast<const uint64_t *>(lwe_array_in),
static_cast<const uint64_t *>(lwe_input_indexes),
static_cast<const uint32_t *>(ksk), lwe_dimension_in, lwe_dimension_out,
base_log, level_count, num_samples,
static_cast<const ks_mem<uint64_t> *>(ks_tmp_buffer)->d_buffer,
uses_trivial_indices);
}
void cuda_keyswitch_lwe_ciphertext_vector_64_64(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void const *lwe_output_indexes, void const *lwe_array_in,
void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
uint32_t num_samples) {
host_keyswitch_lwe_ciphertext_vector<uint64_t>(
host_keyswitch_lwe_ciphertext_vector<uint64_t, uint64_t>(
static_cast<cudaStream_t>(stream), gpu_index,
static_cast<uint64_t *>(lwe_array_out),
static_cast<uint64_t const *>(lwe_output_indexes),
@@ -73,6 +80,22 @@ void cuda_keyswitch_lwe_ciphertext_vector_64(
base_log, level_count, num_samples);
}
void cuda_keyswitch_lwe_ciphertext_vector_64_32(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void const *lwe_output_indexes, void const *lwe_array_in,
void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
uint32_t num_samples) {
host_keyswitch_lwe_ciphertext_vector<uint64_t, uint32_t>(
static_cast<cudaStream_t>(stream), gpu_index,
static_cast<uint32_t *>(lwe_array_out),
static_cast<const uint64_t *>(lwe_output_indexes),
static_cast<const uint64_t *>(lwe_array_in),
static_cast<const uint64_t *>(lwe_input_indexes),
static_cast<const uint32_t *>(ksk), lwe_dimension_in, lwe_dimension_out,
base_log, level_count, num_samples);
}
uint64_t scratch_packing_keyswitch_lwe_list_to_glwe_64(
void *stream, uint32_t gpu_index, int8_t **fp_ks_buffer,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
@@ -159,3 +182,12 @@ void cuda_packing_keyswitch_lwe_list_to_glwe_128(
input_lwe_dimension, output_glwe_dimension, output_polynomial_size,
base_log, level_count, num_lwes);
}
void cuda_closest_representable_64(void *stream, uint32_t gpu_index,
void const *input, void *output,
uint32_t base_log, uint32_t level_count) {
host_cuda_closest_representable(static_cast<cudaStream_t>(stream), gpu_index,
static_cast<const uint64_t *>(input),
static_cast<uint64_t *>(output), base_log,
level_count);
}

View File

@@ -12,7 +12,6 @@
#include "utils/helper.cuh"
#include "utils/kernel_dimensions.cuh"
#include <thread>
#include <unistd.h>
#include <vector>
const int BLOCK_SIZE_DECOMP = 8;
@@ -46,10 +45,42 @@ __device__ Torus *get_ith_block(Torus *ksk, int i, int level,
return ptr;
}
template <typename T>
__device__ T closest_repr(T input, uint32_t base_log, uint32_t level_count) {
T minus_2 = static_cast<T>(-2);
const T rep_bit_count = level_count * base_log; // e.g. 32 for a 64->32 KS
const T non_rep_bit_count = sizeof(T) * 8 - rep_bit_count; // e.g. 32
auto shift = (non_rep_bit_count - 1); // e.g. 31
T res = input >> shift;
res++;
res &= minus_2;
res <<= shift;
return res;
}
template <typename T>
__global__ void closest_representable(const T *input, T *output,
uint32_t base_log, uint32_t level_count) {
output[0] = closest_repr(input[0], base_log, level_count);
}
template <typename T>
__host__ void
host_cuda_closest_representable(cudaStream_t stream, uint32_t gpu_index,
const T *input, T *output, uint32_t base_log,
uint32_t level_count) {
dim3 grid(1, 1, 1);
dim3 threads(1, 1, 1);
cuda_set_device(gpu_index);
closest_representable<<<grid, threads, 0, stream>>>(input, output, base_log,
level_count);
}
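
For intuition, a host-side mirror of closest_repr with a worked 64->32 value (a sketch assuming T = uint64_t; the final downscale mirrors the rounded_downscaled_body logic used further down):

#include <cstdint>
#include <cstdio>

// Round input to the nearest multiple of 2^(64 - base_log * level_count),
// ties rounding up -- the same steps as the device closest_repr above.
static uint64_t closest_repr_host(uint64_t input, uint32_t base_log,
                                  uint32_t level_count) {
  const uint64_t shift = 64 - (uint64_t)base_log * level_count - 1;
  uint64_t res = input >> shift; // keep one extra bit below the kept ones
  res++;                         // round to nearest
  res &= ~uint64_t{1};           // drop the rounding bit
  return res << shift;
}

int main() {
  uint64_t x = 0x123456789ABCDEF0ull;
  // 32 representable bits: round to the nearest multiple of 2^32.
  uint64_t r = closest_repr_host(x, /*base_log=*/32, /*level_count=*/1);
  printf("%016llx -> %016llx\n", (unsigned long long)x, (unsigned long long)r);
  // prints: 123456789abcdef0 -> 1234567900000000
  // For a 64->32 keyswitch the rounded body is then scaled into 32 bits:
  uint32_t body_32 = (uint32_t)(r >> 32); // 0x12345679
  (void)body_32;
  return 0;
}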
// Initialize decomposition by performing rounding
// and decomposing one level of an array of Torus LWEs. Only
// decomposes the mask elements of the incoming LWEs.
template <typename Torus>
template <typename Torus, typename KSTorus>
__global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
uint32_t lwe_dimension,
uint32_t num_lwe, uint32_t base_log,
@@ -76,7 +107,9 @@ __global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
Torus state = init_decomposer_state(a_i, base_log, level_count);
Torus mod_b_mask = (1ll << base_log) - 1ll;
lwe_out[write_val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
KSTorus *kst_ptr_lwe_out = (KSTorus *)lwe_out;
kst_ptr_lwe_out[write_val_idx] =
decompose_one<Torus>(state, mod_b_mask, base_log);
__syncthreads();
lwe_out[write_state_idx] = state;
}
@@ -86,7 +119,7 @@ __global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
// from num_lwe. The maximum index should be <= total_lwe. num_lwe is the number
// of LWEs to decompose. The output buffer should have space for num_lwe LWEs.
// These will be sorted according to the input indices.
template <typename Torus>
template <typename Torus, typename KSTorus>
__global__ void decompose_vectorize_init_with_indices(
Torus const *lwe_in, const Torus *__restrict__ lwe_input_indices,
Torus *lwe_out, uint32_t lwe_dimension, uint32_t num_lwe, uint32_t base_log,
@@ -114,7 +147,9 @@ __global__ void decompose_vectorize_init_with_indices(
Torus state = init_decomposer_state(a_i, base_log, level_count);
Torus mod_b_mask = (1ll << base_log) - 1ll;
lwe_out[write_val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
KSTorus *kst_ptr_lwe_out = (KSTorus *)lwe_out;
kst_ptr_lwe_out[write_val_idx] =
decompose_one<Torus>(state, mod_b_mask, base_log);
__syncthreads();
lwe_out[write_state_idx] = state;
}
@@ -122,7 +157,7 @@ __global__ void decompose_vectorize_init_with_indices(
// Continues the decomposition of an array of Torus elements in place. Assumes
// that the array already contains decomposed elements and
// computes the next decomposition level.
template <typename Torus>
template <typename Torus, typename KSTorus>
__global__ void
decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
uint32_t num_lwe, uint32_t base_log,
@@ -144,15 +179,22 @@ decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
Torus mod_b_mask = (1ll << base_log) - 1ll;
buffer_in[val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
KSTorus *kst_ptr_lwe_out = (KSTorus *)buffer_in;
kst_ptr_lwe_out[val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
__syncthreads();
buffer_in[state_idx] = state;
}
template <typename Torus>
/* LWE inputs to the keyswitch function are stored as a_0,...,a_{lwe_dim},b,
 * where the a_i are mask elements and b is the body. We initialize
 * the output keyswitched LWEs to 0, ..., 0, -b. The GEMM keyswitch is computed
* as:
* -(-b + sum(a_i A_KSK))
*/
template <typename Torus, typename KSTorus>
__global__ void keyswitch_gemm_copy_negated_message_with_indices(
const Torus *__restrict__ lwe_in,
const Torus *__restrict__ lwe_input_indices, Torus *__restrict__ lwe_out,
const Torus *__restrict__ lwe_input_indices, KSTorus *__restrict__ lwe_out,
const Torus *__restrict__ lwe_output_indices,
uint32_t lwe_dimension_in, uint32_t num_lwes, uint32_t lwe_dimension_out) {
@@ -165,16 +207,39 @@ __global__ void keyswitch_gemm_copy_negated_message_with_indices(
uint32_t lwe_in_idx = lwe_input_indices[lwe_id];
uint32_t lwe_out_idx = lwe_output_indices[lwe_id];
Torus body_in =
lwe_in[lwe_in_idx * (lwe_dimension_in + 1) + lwe_dimension_in];
Torus body_out;
if constexpr (std::is_same_v<KSTorus, Torus>) {
body_out = -body_in;
} else {
body_out = closest_repr(
lwe_in[lwe_in_idx * (lwe_dimension_in + 1) + lwe_dimension_in],
sizeof(KSTorus) * 8, 1);
// Powers of two are encoded in the MSBs of the types, so we can rescale
// from one type to the other without having to worry about the moduli
static_assert(sizeof(Torus) >= sizeof(KSTorus),
"Cannot compile keyswitch with given input/output dtypes");
Torus input_to_output_scaling_factor =
(sizeof(Torus) - sizeof(KSTorus)) * 8;
auto rounded_downscaled_body =
(KSTorus)(body_out >> input_to_output_scaling_factor);
body_out = -rounded_downscaled_body;
}
lwe_out[lwe_out_idx * (lwe_dimension_out + 1) + lwe_dimension_out] =
-lwe_in[lwe_in_idx * (lwe_dimension_in + 1) + lwe_dimension_in];
(KSTorus)body_out;
}
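
In equation form (our summary, combining the comments above with standard LWE keyswitching): writing $\langle a_i \rangle_j$ for the $j$-th digit of the base-$2^{\mathrm{base\_log}}$ gadget decomposition of the rounded mask element $a_i$, and $\ell$ for level_count,

$$ \mathrm{KS}(a_0, \dots, a_{n_{\mathrm{in}}-1}, b) = -\Big(-b + \sum_{i=0}^{n_{\mathrm{in}}-1} \sum_{j=0}^{\ell-1} \langle a_i \rangle_j \cdot \mathrm{KSK}_{i,j} \Big), $$

where each $\mathrm{KSK}_{i,j}$ is an LWE encryption, under the output key, of the corresponding gadget multiple of the $i$-th input key element. The kernels below compute the inner sum with one GEMM per decomposition level and negate at the end.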
// Finishes the KS computation by negating all elements in the array
// using output indices. The array contains -b + SUM(a_i x LWE_i)
// and this final step computes b - SUM(a_i x LWE_i)
template <typename Torus>
// The GEMM keyswitch is computed as: -(-b + sum(a_i A_KSK)).
// This function finishes the KS computation by negating all elements in the
// array using output indices. The array contains -b + SUM(a_i x LWE_i) and this
// final step computes b - SUM(a_i x LWE_i).
template <typename Torus, typename KSTorus>
__global__ void keyswitch_negate_with_output_indices(
Torus *buffer_in, const Torus *__restrict__ lwe_output_indices,
KSTorus *buffer_in, const Torus *__restrict__ lwe_output_indices,
uint32_t lwe_size, uint32_t num_lwe) {
// index of this LWE ct in the buffer
@@ -191,9 +256,9 @@ __global__ void keyswitch_negate_with_output_indices(
buffer_in[val_idx] = -val;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__global__ void keyswitch_zero_output_with_output_indices(
Torus *buffer_in, const Torus *__restrict__ lwe_output_indices,
KSTorus *buffer_in, const Torus *__restrict__ lwe_output_indices,
uint32_t lwe_size, uint32_t num_lwe) {
// index of this LWE ct in the buffer
@@ -235,12 +300,12 @@ __global__ void keyswitch_zero_output_with_output_indices(
// in two parts, a constant part is calculated before the loop, and a variable
// part is calculated inside the loop. This seems to help with the register
// pressure as well.
template <typename Torus>
template <typename Torus, typename KSTorus>
__global__ void
keyswitch(Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
keyswitch(KSTorus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
const Torus *__restrict__ lwe_array_in,
const Torus *__restrict__ lwe_input_indexes,
const Torus *__restrict__ ksk, uint32_t lwe_dimension_in,
const KSTorus *__restrict__ ksk, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count) {
const int tid = threadIdx.x + blockIdx.y * blockDim.x;
const int shmem_index = threadIdx.x + threadIdx.y * blockDim.x;
@@ -252,12 +317,27 @@ keyswitch(Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
if (tid <= lwe_dimension_out) {
Torus local_lwe_out = 0;
KSTorus local_lwe_out = 0;
auto block_lwe_array_in = get_chunk(
lwe_array_in, lwe_input_indexes[blockIdx.x], lwe_dimension_in + 1);
if (tid == lwe_dimension_out && threadIdx.y == 0) {
local_lwe_out = -block_lwe_array_in[lwe_dimension_in];
if constexpr (std::is_same_v<KSTorus, Torus>) {
local_lwe_out = -block_lwe_array_in[lwe_dimension_in];
} else {
auto new_body = closest_repr(block_lwe_array_in[lwe_dimension_in],
sizeof(KSTorus) * 8, 1);
// Powers of two are encoded in the MSBs of the types, so we can rescale
// from one type to the other without having to worry about the moduli
Torus input_to_output_scaling_factor =
(sizeof(Torus) - sizeof(KSTorus)) * 8;
auto rounded_downscaled_body =
(KSTorus)(new_body >> input_to_output_scaling_factor);
local_lwe_out = -rounded_downscaled_body;
}
}
const Torus mask_mod_b = (1ll << base_log) - 1ll;
@@ -273,9 +353,10 @@ keyswitch(Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
uint32_t offset = i * level_count * (lwe_dimension_out + 1);
for (int j = 0; j < level_count; j++) {
Torus decomposed = decompose_one<Torus>(state, mask_mod_b, base_log);
KSTorus decomposed = decompose_one<Torus>(state, mask_mod_b, base_log);
local_lwe_out +=
(Torus)ksk[tid + j * (lwe_dimension_out + 1) + offset] * decomposed;
(KSTorus)ksk[tid + j * (lwe_dimension_out + 1) + offset] *
decomposed;
}
}
@@ -294,13 +375,13 @@ keyswitch(Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
}
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_keyswitch_lwe_ciphertext_vector(
cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out,
cudaStream_t stream, uint32_t gpu_index, KSTorus *lwe_array_out,
Torus const *lwe_output_indexes, Torus const *lwe_array_in,
Torus const *lwe_input_indexes, Torus const *ksk, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
uint32_t num_samples) {
Torus const *lwe_input_indexes, KSTorus const *ksk,
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
uint32_t level_count, uint32_t num_samples) {
cuda_set_device(gpu_index);
@@ -322,29 +403,36 @@ __host__ void host_keyswitch_lwe_ciphertext_vector(
dim3 grid(num_samples, num_blocks_per_sample, 1);
dim3 threads(num_threads_x, num_threads_y, 1);
keyswitch<Torus><<<grid, threads, shared_mem, stream>>>(
keyswitch<Torus, KSTorus><<<grid, threads, shared_mem, stream>>>(
lwe_array_out, lwe_output_indexes, lwe_array_in, lwe_input_indexes, ksk,
lwe_dimension_in, lwe_dimension_out, base_log, level_count);
check_cuda_error(cudaGetLastError());
}
template <typename Torus>
// The GEMM keyswitch is computed as: -(-b + sum(a_i A_KSK))
template <typename Torus, typename KSTorus>
__host__ void host_gemm_keyswitch_lwe_ciphertext_vector(
cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out,
cudaStream_t stream, uint32_t gpu_index, KSTorus *lwe_array_out,
Torus const *lwe_output_indices, Torus const *lwe_array_in,
Torus const *lwe_input_indices, Torus const *ksk, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
uint32_t num_samples, Torus *fp_tmp_buffer, bool uses_trivial_indices) {
Torus const *lwe_input_indices, KSTorus const *ksk,
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
uint32_t level_count, uint32_t num_samples, Torus *fp_tmp_buffer,
bool uses_trivial_indices) {
cuda_set_device(gpu_index);
check_cuda_error(cudaGetLastError());
auto d_mem_0 = fp_tmp_buffer; // keeps decomposed value
// fp_tmp_buffer contains 2x the space needed to store the input LWE masks
// without the body. The first half can be interpreted with a smaller dtype
// when performing 64->32 KS; the second half, which stores the decomposition
// state, must be interpreted as Torus* (usually 64-bit)
KSTorus *d_mem_0 =
(KSTorus *)fp_tmp_buffer; // keeps decomposed value (in KSTorus type)
// Zero the output buffer, as it is used to accumulate
// the GEMM partial results
if (uses_trivial_indices) {
cuda_memset_async(lwe_array_out, 0,
num_samples * (lwe_dimension_out + 1) * sizeof(Torus),
num_samples * (lwe_dimension_out + 1) * sizeof(KSTorus),
stream, gpu_index);
} else {
// gemm to ks the individual LWEs to GLWEs
@@ -352,7 +440,7 @@ __host__ void host_gemm_keyswitch_lwe_ciphertext_vector(
CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP));
dim3 threads_zero(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);
keyswitch_zero_output_with_output_indices<Torus>
keyswitch_zero_output_with_output_indices<Torus, KSTorus>
<<<grid_zero, threads_zero, 0, stream>>>(
lwe_array_out, lwe_output_indices, lwe_dimension_out + 1,
num_samples);
@@ -364,8 +452,8 @@ __host__ void host_gemm_keyswitch_lwe_ciphertext_vector(
// lwe_array_out is num_samples x (lwe_dimension_out + 1). Copy the bodies
// lwe_array_in[:,lwe_dimension_in] to lwe_array_out[:,lwe_dimension_out]
// and negate
keyswitch_gemm_copy_negated_message_with_indices<Torus>
// and negates them
keyswitch_gemm_copy_negated_message_with_indices<Torus, KSTorus>
<<<grid_copy, threads_copy, 0, stream>>>(
lwe_array_in, lwe_input_indices, lwe_array_out, lwe_output_indices,
lwe_dimension_in, num_samples, lwe_dimension_out);
@@ -394,21 +482,21 @@ __host__ void host_gemm_keyswitch_lwe_ciphertext_vector(
dim3 threads_gemm(BLOCK_SIZE_GEMM_KS * THREADS_GEMM_KS);
// decompose first level (skips the body in the input buffer)
decompose_vectorize_init_with_indices<Torus>
decompose_vectorize_init_with_indices<Torus, KSTorus>
<<<grid_decomp, threads_decomp, 0, stream>>>(
lwe_array_in, lwe_input_indices, fp_tmp_buffer, lwe_dimension_in,
num_samples, base_log, level_count);
check_cuda_error(cudaGetLastError());
if (uses_trivial_indices) {
tgemm<Torus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
tgemm<KSTorus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
ksk, stride_KSK_buffer, lwe_array_out, lwe_dimension_out + 1);
check_cuda_error(cudaGetLastError());
} else {
tgemm_with_indices<Torus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
tgemm_with_indices<KSTorus, Torus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
ksk, stride_KSK_buffer, lwe_array_out, lwe_dimension_out + 1,
@@ -419,14 +507,14 @@ __host__ void host_gemm_keyswitch_lwe_ciphertext_vector(
auto ksk_block_size = (lwe_dimension_out + 1);
for (int li = 1; li < level_count; ++li) {
decompose_vectorize_step_inplace<Torus>
decompose_vectorize_step_inplace<Torus, KSTorus>
<<<grid_decomp, threads_decomp, 0, stream>>>(
fp_tmp_buffer, lwe_dimension_in, num_samples, base_log,
level_count);
check_cuda_error(cudaGetLastError());
if (uses_trivial_indices) {
tgemm<Torus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
tgemm<KSTorus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
ksk + li * ksk_block_size, stride_KSK_buffer, lwe_array_out,
@@ -434,7 +522,7 @@ __host__ void host_gemm_keyswitch_lwe_ciphertext_vector(
check_cuda_error(cudaGetLastError());
} else {
tgemm_with_indices<Torus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
tgemm_with_indices<KSTorus, Torus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
ksk + li * ksk_block_size, stride_KSK_buffer, lwe_array_out,
@@ -447,20 +535,22 @@ __host__ void host_gemm_keyswitch_lwe_ciphertext_vector(
dim3 grid_negate(CEIL_DIV(lwe_dimension_out + 1, BLOCK_SIZE_DECOMP),
CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP));
dim3 threads_negate(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);
// Negate all outputs in the LWE
keyswitch_negate_with_output_indices<Torus>
// Negate all outputs in the output LWEs. This is the final step in the GEMM
// keyswitch computed as: -(-b + sum(a_i A_KSK))
keyswitch_negate_with_output_indices<Torus, KSTorus>
<<<grid_negate, threads_negate, 0, stream>>>(
lwe_array_out, lwe_output_indices, lwe_dimension_out + 1,
num_samples);
check_cuda_error(cudaGetLastError());
}
template <typename Torus>
template <typename Torus, typename KSTorus>
void execute_keyswitch_async(
CudaStreams streams, const LweArrayVariant<Torus> &lwe_array_out,
const LweArrayVariant<Torus> &lwe_output_indexes,
const LweArrayVariant<Torus> &lwe_array_in,
const LweArrayVariant<Torus> &lwe_input_indexes, Torus *const *ksks,
const LweArrayVariant<Torus> &lwe_input_indexes, KSTorus *const *ksks,
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
uint32_t level_count, uint32_t num_samples, bool uses_trivial_indices,
const std::vector<ks_mem<Torus> *> &fp_tmp_buffer) {

View File

@@ -124,8 +124,10 @@ __host__ void host_packing_keyswitch_lwe_list_to_glwe(
dim3 threads_decomp(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);
// decompose first level
decompose_vectorize_init<Torus><<<grid_decomp, threads_decomp, 0, stream>>>(
lwe_array_in, d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
decompose_vectorize_init<Torus, Torus>
<<<grid_decomp, threads_decomp, 0, stream>>>(lwe_array_in, d_mem_0,
lwe_dimension, num_lwes,
base_log, level_count);
check_cuda_error(cudaGetLastError());
// gemm to ks the individual LWEs to GLWEs
@@ -151,7 +153,7 @@ __host__ void host_packing_keyswitch_lwe_list_to_glwe(
auto ksk_block_size = glwe_accumulator_size;
for (int li = 1; li < level_count; ++li) {
decompose_vectorize_step_inplace<Torus>
decompose_vectorize_step_inplace<Torus, Torus>
<<<grid_decomp, threads_decomp, 0, stream>>>(
d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
check_cuda_error(cudaGetLastError());

View File

@@ -22,13 +22,13 @@ __host__ uint64_t scratch_cuda_boolean_bitop(
return size_tracker;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_boolean_bitop(CudaStreams streams,
CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_1,
CudaRadixCiphertextFFI const *lwe_array_2,
boolean_bitop_buffer<Torus> *mem_ptr,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
PANIC_IF_FALSE(
lwe_array_out->num_radix_blocks == lwe_array_1->num_radix_blocks &&
@@ -203,11 +203,11 @@ __host__ uint64_t scratch_cuda_boolean_bitnot(
return size_tracker;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_boolean_bitnot(CudaStreams streams,
CudaRadixCiphertextFFI *lwe_array,
boolean_bitnot_buffer<Torus> *mem_ptr,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
bool carries_empty = true;
for (size_t i = 0; i < lwe_array->num_radix_blocks; ++i) {
if (lwe_array->degrees[i] >= mem_ptr->params.message_modulus) {
@@ -228,13 +228,13 @@ __host__ void host_boolean_bitnot(CudaStreams streams,
// this function calls `host_bitnot` with `ct_message_modulus = 2`
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_bitop(CudaStreams streams,
CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_1,
CudaRadixCiphertextFFI const *lwe_array_2,
int_bitop_buffer<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks) {
KSTorus *const *ksks) {
PANIC_IF_FALSE(
lwe_array_out->num_radix_blocks == lwe_array_1->num_radix_blocks &&

View File

@@ -68,12 +68,12 @@ __host__ uint64_t scratch_extend_radix_with_sign_msb(
return size_tracker;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_extend_radix_with_sign_msb(
CudaStreams streams, CudaRadixCiphertextFFI *output,
CudaRadixCiphertextFFI const *input,
int_extend_radix_with_sign_msb_buffer<Torus> *mem_ptr,
uint32_t num_additional_blocks, void *const *bsks, Torus *const *ksks) {
uint32_t num_additional_blocks, void *const *bsks, KSTorus *const *ksks) {
if (num_additional_blocks == 0) {
PUSH_RANGE("cast/extend no addblocks")

View File

@@ -5,14 +5,14 @@
#include "integer/cmux.h"
#include "radix_ciphertext.cuh"
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void zero_out_if(CudaStreams streams,
CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_input,
CudaRadixCiphertextFFI const *lwe_condition,
int_zero_out_if_buffer<Torus> *mem_ptr,
int_radix_lut<Torus> *predicate, void *const *bsks,
Torus *const *ksks, uint32_t num_radix_blocks) {
KSTorus *const *ksks, uint32_t num_radix_blocks) {
PANIC_IF_FALSE(
lwe_array_out->num_radix_blocks >= num_radix_blocks &&
lwe_array_input->num_radix_blocks >= num_radix_blocks,
@@ -41,14 +41,14 @@ __host__ void zero_out_if(CudaStreams streams,
num_radix_blocks);
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_cmux(CudaStreams streams,
CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_condition,
CudaRadixCiphertextFFI const *lwe_array_true,
CudaRadixCiphertextFFI const *lwe_array_false,
int_cmux_buffer<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks) {
KSTorus *const *ksks) {
if (lwe_array_out->num_radix_blocks != lwe_array_true->num_radix_blocks)
PANIC("Cuda error: input and output num radix blocks must be the same")

View File

@@ -56,12 +56,12 @@ __host__ void accumulate_all_blocks(cudaStream_t stream, uint32_t gpu_index,
* blocks are 1; otherwise the block encrypts 0
*
*/
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void are_all_comparisons_block_true(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_in,
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks, uint32_t num_radix_blocks) {
KSTorus *const *ksks, uint32_t num_radix_blocks) {
if (lwe_array_out->lwe_dimension != lwe_array_in->lwe_dimension)
PANIC("Cuda error: input and output lwe dimensions must be the same")
@@ -191,12 +191,12 @@ __host__ void are_all_comparisons_block_true(
* It writes in lwe_array_out a single LWE ciphertext encrypting 1 if at least
* one input ciphertext encrypts 1; otherwise it encrypts 0
*/
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void is_at_least_one_comparisons_block_true(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_in,
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks, uint32_t num_radix_blocks) {
KSTorus *const *ksks, uint32_t num_radix_blocks) {
if (lwe_array_out->lwe_dimension != lwe_array_in->lwe_dimension)
PANIC("Cuda error: input lwe dimensions must be the same")
@@ -260,12 +260,12 @@ __host__ void is_at_least_one_comparisons_block_true(
}
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_compare_blocks_with_zero(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_in,
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks, int32_t num_radix_blocks,
KSTorus *const *ksks, int32_t num_radix_blocks,
int_radix_lut<Torus> *zero_comparison) {
if (num_radix_blocks == 0)
@@ -327,13 +327,13 @@ __host__ void host_compare_blocks_with_zero(
reset_radix_ciphertext_blocks(lwe_array_out, num_sum_blocks);
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void
host_equality_check(CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_1,
CudaRadixCiphertextFFI const *lwe_array_2,
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks, uint32_t num_radix_blocks) {
KSTorus *const *ksks, uint32_t num_radix_blocks) {
if (lwe_array_out->lwe_dimension != lwe_array_1->lwe_dimension ||
lwe_array_out->lwe_dimension != lwe_array_2->lwe_dimension)
@@ -355,13 +355,13 @@ host_equality_check(CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
mem_ptr, bsks, ksks, num_radix_blocks);
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void
compare_radix_blocks(CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_left,
CudaRadixCiphertextFFI const *lwe_array_right,
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks, uint32_t num_radix_blocks) {
KSTorus *const *ksks, uint32_t num_radix_blocks) {
if (lwe_array_out->lwe_dimension != lwe_array_left->lwe_dimension ||
lwe_array_out->lwe_dimension != lwe_array_right->lwe_dimension)
@@ -407,13 +407,13 @@ compare_radix_blocks(CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
// Reduces a vec of shortint blocks that encrypt a sign
// (inferior, equal, superior) to a single shortint block containing the
// final sign
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void
tree_sign_reduction(CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI *lwe_block_comparisons,
int_tree_sign_reduction_buffer<Torus> *tree_buffer,
std::function<Torus(Torus)> sign_handler_f,
void *const *bsks, Torus *const *ksks,
void *const *bsks, KSTorus *const *ksks,
uint32_t num_radix_blocks) {
if (lwe_array_out->lwe_dimension != lwe_block_comparisons->lwe_dimension)
@@ -496,14 +496,14 @@ tree_sign_reduction(CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
bsks, ksks, last_lut, 1);
}
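
A cleartext analogue of this reduction (our sketch; the actual shortint encoding of the three sign values may differ):

#include <cstddef>
#include <cstdint>

enum Sign : uint8_t { INFERIOR, EQUAL, SUPERIOR };

// Combine two block signs, most significant first: EQUAL defers to the
// lower-order result, anything else already decides the comparison.
static Sign combine_signs(Sign high, Sign low) {
  return high == EQUAL ? low : high;
}

// Reduce per-block comparison signs (signs[0] = most significant block).
static Sign reduce_signs_clear(const Sign *signs, size_t n) {
  Sign acc = EQUAL;
  for (size_t i = 0; i < n; ++i)
    acc = combine_signs(acc, signs[i]);
  return acc;
}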
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_difference_check(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_left,
CudaRadixCiphertextFFI const *lwe_array_right,
int_comparison_buffer<Torus> *mem_ptr,
std::function<Torus(Torus)> reduction_lut_f, void *const *bsks,
Torus *const *ksks, uint32_t num_radix_blocks) {
KSTorus *const *ksks, uint32_t num_radix_blocks) {
if (lwe_array_out->lwe_dimension != lwe_array_left->lwe_dimension ||
lwe_array_out->lwe_dimension != lwe_array_right->lwe_dimension)
@@ -664,13 +664,13 @@ __host__ uint64_t scratch_cuda_comparison_check(
return size_tracker;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void
host_maxmin(CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_left,
CudaRadixCiphertextFFI const *lwe_array_right,
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks, uint32_t num_radix_blocks) {
KSTorus *const *ksks, uint32_t num_radix_blocks) {
if (lwe_array_out->lwe_dimension != lwe_array_left->lwe_dimension ||
lwe_array_out->lwe_dimension != lwe_array_right->lwe_dimension)
@@ -692,12 +692,12 @@ host_maxmin(CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
ksks);
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_integer_are_all_comparisons_block_true(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_in,
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks, uint32_t num_radix_blocks) {
KSTorus *const *ksks, uint32_t num_radix_blocks) {
// It returns a block encrypting 1 if all input blocks are 1;
// otherwise the block encrypts 0
@@ -705,12 +705,12 @@ __host__ void host_integer_are_all_comparisons_block_true(
mem_ptr, bsks, ksks, num_radix_blocks);
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_integer_is_at_least_one_comparisons_block_true(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_in,
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks, uint32_t num_radix_blocks) {
KSTorus *const *ksks, uint32_t num_radix_blocks) {
// It returns a block encrypting 1 if at least one input block is 1;
// otherwise the block encrypts 0

View File

@@ -31,7 +31,7 @@ void cuda_integer_count_of_consecutive_bits_64(
CudaRadixCiphertextFFI const *input_ct, int8_t *mem_ptr, void *const *bsks,
void *const *ksks) {
host_integer_count_of_consecutive_bits<uint64_t>(
host_integer_count_of_consecutive_bits<uint64_t, uint64_t>(
CudaStreams(streams), output_ct, input_ct,
(int_count_of_consecutive_bits_buffer<uint64_t> *)mem_ptr, bsks,
(uint64_t **)ksks);
@@ -82,7 +82,7 @@ void cuda_integer_ilog2_64(
CudaRadixCiphertextFFI const *trivial_ct_m_minus_1_block, int8_t *mem_ptr,
void *const *bsks, void *const *ksks) {
host_integer_ilog2<uint64_t>(
host_integer_ilog2<uint64_t, uint64_t>(
CudaStreams(streams), output_ct, input_ct, trivial_ct_neg_n, trivial_ct_2,
trivial_ct_m_minus_1_block, (int_ilog2_buffer<uint64_t> *)mem_ptr, bsks,
(uint64_t **)ksks);

View File

@@ -6,11 +6,11 @@
#include "integer/integer_utilities.h"
#include "multiplication.cuh"
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_integer_prepare_count_of_consecutive_bits(
CudaStreams streams, CudaRadixCiphertextFFI *ciphertext,
int_prepare_count_of_consecutive_bits_buffer<Torus> *mem_ptr,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
auto tmp = mem_ptr->tmp_ct;
@@ -42,12 +42,12 @@ __host__ uint64_t scratch_integer_count_of_consecutive_bits(
return size_tracker;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_integer_count_of_consecutive_bits(
CudaStreams streams, CudaRadixCiphertextFFI *output_ct,
CudaRadixCiphertextFFI const *input_ct,
int_count_of_consecutive_bits_buffer<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks) {
KSTorus *const *ksks) {
auto params = mem_ptr->params;
auto ct_prepared = mem_ptr->ct_prepared;
@@ -98,7 +98,7 @@ __host__ uint64_t scratch_integer_ilog2(CudaStreams streams,
return size_tracker;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void
host_integer_ilog2(CudaStreams streams, CudaRadixCiphertextFFI *output_ct,
CudaRadixCiphertextFFI const *input_ct,
@@ -106,7 +106,7 @@ host_integer_ilog2(CudaStreams streams, CudaRadixCiphertextFFI *output_ct,
CudaRadixCiphertextFFI const *trivial_ct_2,
CudaRadixCiphertextFFI const *trivial_ct_m_minus_1_block,
int_ilog2_buffer<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks) {
KSTorus *const *ksks) {
// Prepare the input ciphertext by computing the number of consecutive
// leading zeros for each of its blocks.

View File

@@ -502,11 +502,12 @@ __host__ void host_pack_bivariate_blocks_with_single_block(
/// num_radix_blocks corresponds to the number of blocks on which to apply the
/// LUT. In scalar bitops we use a number of blocks that may be lower than or
/// equal to the input and output numbers of blocks
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void integer_radix_apply_univariate_lookup_table(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_in, void *const *bsks,
Torus *const *ksks, int_radix_lut<Torus> *lut, uint32_t num_radix_blocks) {
KSTorus *const *ksks, int_radix_lut<Torus> *lut,
uint32_t num_radix_blocks) {
PUSH_RANGE("apply lut")
// apply_lookup_table
auto params = lut->params;
@@ -607,11 +608,11 @@ __host__ void integer_radix_apply_univariate_lookup_table(
POP_RANGE()
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void integer_radix_apply_many_univariate_lookup_table(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_in, void *const *bsks,
Torus *const *ksks, int_radix_lut<Torus> *lut, uint32_t num_many_lut,
KSTorus *const *ksks, int_radix_lut<Torus> *lut, uint32_t num_many_lut,
uint32_t lut_stride) {
PUSH_RANGE("apply many lut")
// apply_lookup_table
@@ -711,12 +712,12 @@ __host__ void integer_radix_apply_many_univariate_lookup_table(
POP_RANGE()
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void integer_radix_apply_bivariate_lookup_table(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_1,
CudaRadixCiphertextFFI const *lwe_array_2, void *const *bsks,
Torus *const *ksks, int_radix_lut<Torus> *lut, uint32_t num_radix_blocks,
KSTorus *const *ksks, int_radix_lut<Torus> *lut, uint32_t num_radix_blocks,
uint32_t shift) {
PUSH_RANGE("apply bivar lut")
if (lwe_array_out->lwe_dimension != lwe_array_1->lwe_dimension ||
@@ -1271,11 +1272,11 @@ void generate_many_lut_device_accumulator(
// block states: contains the propagation state for each block,
// depending on the group it belongs to and its internal position within the
// block.
template <typename Torus>
template <typename Torus, typename KSTorus>
void host_compute_shifted_blocks_and_states(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array,
int_shifted_blocks_and_states_memory<Torus> *mem, void *const *bsks,
Torus *const *ksks, uint32_t lut_stride, uint32_t num_many_lut) {
KSTorus *const *ksks, uint32_t lut_stride, uint32_t num_many_lut) {
auto num_radix_blocks = lwe_array->num_radix_blocks;
@@ -1298,12 +1299,12 @@ void host_compute_shifted_blocks_and_states(
2 * num_radix_blocks);
}
template <typename Torus>
template <typename Torus, typename KSTorus>
void host_resolve_group_carries_sequentially(
CudaStreams streams, CudaRadixCiphertextFFI *resolved_carries,
CudaRadixCiphertextFFI *grouping_pgns, int_radix_params params,
int_seq_group_prop_memory<Torus> *mem, void *const *bsks,
Torus *const *ksks, uint32_t num_groups) {
KSTorus *const *ksks, uint32_t num_groups) {
auto group_resolved_carries = mem->group_resolved_carries;
if (num_groups > 1) {
@@ -1366,11 +1367,11 @@ void host_resolve_group_carries_sequentially(
}
}
template <typename Torus>
template <typename Torus, typename KSTorus>
void host_compute_prefix_sum_hillis_steele(
CudaStreams streams, CudaRadixCiphertextFFI *step_output,
CudaRadixCiphertextFFI *generates_or_propagates, int_radix_lut<Torus> *luts,
void *const *bsks, Torus *const *ksks, uint32_t num_radix_blocks) {
void *const *bsks, KSTorus *const *ksks, uint32_t num_radix_blocks) {
if (step_output->lwe_dimension != generates_or_propagates->lwe_dimension)
PANIC("Cuda error: input lwe dimensions must be the same")
@@ -1409,11 +1410,11 @@ void host_compute_prefix_sum_hillis_steele(
// - calculates the propagation state of each group
// - resolves the carries between groups, either sequentially or with
// Hillis-Steele
template <typename Torus>
template <typename Torus, typename KSTorus>
void host_compute_propagation_simulators_and_group_carries(
CudaStreams streams, CudaRadixCiphertextFFI *block_states,
int_radix_params params, int_prop_simu_group_carries_memory<Torus> *mem,
void *const *bsks, Torus *const *ksks, uint32_t num_radix_blocks,
void *const *bsks, KSTorus *const *ksks, uint32_t num_radix_blocks,
uint32_t num_groups) {
if (num_radix_blocks > block_states->num_radix_blocks)
@@ -1471,11 +1472,11 @@ void host_compute_propagation_simulators_and_group_carries(
// block states: contains the propagation state for each block,
// depending on the group it belongs to and its internal position within the
// block.
template <typename Torus>
template <typename Torus, typename KSTorus>
void host_compute_shifted_blocks_and_borrow_states(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array,
int_shifted_blocks_and_borrow_states_memory<Torus> *mem, void *const *bsks,
Torus *const *ksks, uint32_t lut_stride, uint32_t num_many_lut) {
KSTorus *const *ksks, uint32_t lut_stride, uint32_t num_many_lut) {
auto num_radix_blocks = lwe_array->num_radix_blocks;
auto shifted_blocks_and_borrow_states = mem->shifted_blocks_and_borrow_states;
@@ -1504,11 +1505,11 @@ void host_compute_shifted_blocks_and_borrow_states(
* * (lwe_dimension + 1) * sizeof(Torus). big_lwe_vector: output of the PBS, should
* have size = 2 * (glwe_dimension * polynomial_size + 1) * sizeof(Torus)
*/
template <typename Torus>
template <typename Torus, typename KSTorus>
void host_full_propagate_inplace(CudaStreams streams,
CudaRadixCiphertextFFI *input_blocks,
int_fullprop_buffer<Torus> *mem_ptr,
Torus *const *ksks, void *const *bsks,
KSTorus *const *ksks, void *const *bsks,
uint32_t num_blocks) {
auto params = mem_ptr->lut->params;
@@ -1667,11 +1668,11 @@ __host__ void scalar_pack_blocks(cudaStream_t stream, uint32_t gpu_index,
* Thus, lwe_array_out must be allocated with num_radix_blocks * bits_per_block
* * (lwe_dimension+1) * sizeof(Torus) bytes
*/
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void
extract_n_bits(CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
const CudaRadixCiphertextFFI *lwe_array_in, void *const *bsks,
Torus *const *ksks, uint32_t effective_num_radix_blocks,
KSTorus *const *ksks, uint32_t effective_num_radix_blocks,
uint32_t num_radix_blocks,
int_bit_extract_luts_buffer<Torus> *bit_extract) {
@@ -1691,13 +1692,13 @@ extract_n_bits(CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
effective_num_radix_blocks);
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void
reduce_signs(CudaStreams streams, CudaRadixCiphertextFFI *signs_array_out,
CudaRadixCiphertextFFI *signs_array_in,
int_comparison_buffer<Torus> *mem_ptr,
std::function<Torus(Torus)> sign_handler_f, void *const *bsks,
Torus *const *ksks, uint32_t num_sign_blocks) {
KSTorus *const *ksks, uint32_t num_sign_blocks) {
if (signs_array_out->lwe_dimension != signs_array_in->lwe_dimension)
PANIC("Cuda error: input lwe dimensions must be the same")
@@ -1817,11 +1818,11 @@ uint64_t scratch_cuda_apply_univariate_lut(
return size_tracker;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
void host_apply_univariate_lut(CudaStreams streams,
CudaRadixCiphertextFFI *radix_lwe_out,
CudaRadixCiphertextFFI const *radix_lwe_in,
int_radix_lut<Torus> *mem, Torus *const *ksks,
int_radix_lut<Torus> *mem, KSTorus *const *ksks,
void *const *bsks) {
integer_radix_apply_univariate_lookup_table<Torus>(
@@ -1852,12 +1853,12 @@ uint64_t scratch_cuda_apply_many_univariate_lut(
return size_tracker;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
void host_apply_many_univariate_lut(CudaStreams streams,
CudaRadixCiphertextFFI *radix_lwe_out,
CudaRadixCiphertextFFI const *radix_lwe_in,
int_radix_lut<Torus> *mem,
Torus *const *ksks, void *const *bsks,
KSTorus *const *ksks, void *const *bsks,
uint32_t num_many_lut,
uint32_t lut_stride) {
@@ -1888,12 +1889,12 @@ uint64_t scratch_cuda_apply_bivariate_lut(
return size_tracker;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
void host_apply_bivariate_lut(CudaStreams streams,
CudaRadixCiphertextFFI *radix_lwe_out,
CudaRadixCiphertextFFI const *radix_lwe_in_1,
CudaRadixCiphertextFFI const *radix_lwe_in_2,
int_radix_lut<Torus> *mem, Torus *const *ksks,
int_radix_lut<Torus> *mem, KSTorus *const *ksks,
void *const *bsks, uint32_t num_radix_blocks,
uint32_t shift) {
@@ -1917,13 +1918,13 @@ uint64_t scratch_cuda_propagate_single_carry_inplace(
}
// This function performs the three steps of Thomas' new carry propagation and
// includes the logic to extract the overflow when requested
template <typename Torus>
template <typename Torus, typename KSTorus>
void host_propagate_single_carry(CudaStreams streams,
CudaRadixCiphertextFFI *lwe_array,
CudaRadixCiphertextFFI *carry_out,
const CudaRadixCiphertextFFI *input_carries,
int_sc_prop_memory<Torus> *mem,
void *const *bsks, Torus *const *ksks,
void *const *bsks, KSTorus *const *ksks,
uint32_t requested_flag, uint32_t uses_carry) {
PUSH_RANGE("propagate sc")
auto num_radix_blocks = lwe_array->num_radix_blocks;
@@ -2019,12 +2020,12 @@ void host_propagate_single_carry(CudaStreams streams,
// This function performs the three steps of Thomas' new carry propagation and
// includes the logic to extract the overflow when requested
template <typename Torus>
template <typename Torus, typename KSTorus>
void host_add_and_propagate_single_carry(
CudaStreams streams, CudaRadixCiphertextFFI *lhs_array,
const CudaRadixCiphertextFFI *rhs_array, CudaRadixCiphertextFFI *carry_out,
const CudaRadixCiphertextFFI *input_carries, int_sc_prop_memory<Torus> *mem,
void *const *bsks, Torus *const *ksks, uint32_t requested_flag,
void *const *bsks, KSTorus *const *ksks, uint32_t requested_flag,
uint32_t uses_carry) {
PUSH_RANGE("add & propagate sc")
if (lhs_array->num_radix_blocks != rhs_array->num_radix_blocks)
@@ -2181,13 +2182,13 @@ uint64_t scratch_cuda_integer_overflowing_sub(
// This function performs the three steps of Thomas' new borrow propagation and
// includes the logic to extract the overflow when requested
template <typename Torus>
template <typename Torus, typename KSTorus>
void host_single_borrow_propagate(CudaStreams streams,
CudaRadixCiphertextFFI *lwe_array,
CudaRadixCiphertextFFI *overflow_block,
const CudaRadixCiphertextFFI *input_borrow,
int_borrow_prop_memory<Torus> *mem,
void *const *bsks, Torus *const *ksks,
void *const *bsks, KSTorus *const *ksks,
uint32_t num_groups,
uint32_t compute_overflow,
uint32_t uses_input_borrow) {
@@ -2294,12 +2295,13 @@ void host_single_borrow_propagate(CudaStreams streams,
/// num_radix_blocks corresponds to the number of blocks on which to apply the
/// LUT. In scalar bitops we use a number of blocks that may be lower than or
/// equal to the input and output numbers of blocks
template <typename InputTorus>
__host__ void integer_radix_apply_noise_squashing(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_in,
int_noise_squashing_lut<InputTorus> *lut, void *const *bsks,
InputTorus *const *ksks) {
template <typename InputTorus, typename KSTorus>
__host__ void
integer_radix_apply_noise_squashing(CudaStreams streams,
CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_in,
int_noise_squashing_lut<InputTorus> *lut,
void *const *bsks, KSTorus *const *ksks) {
PUSH_RANGE("apply noise squashing")
auto params = lut->params;

View File

@@ -3,13 +3,13 @@
#include "integer/bitwise_ops.cuh"
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void
host_scalar_bitop(CudaStreams streams, CudaRadixCiphertextFFI *output,
CudaRadixCiphertextFFI const *input,
Torus const *clear_blocks, Torus const *h_clear_blocks,
uint32_t num_clear_blocks, int_bitop_buffer<Torus> *mem_ptr,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
if (output->num_radix_blocks != input->num_radix_blocks)
PANIC("Cuda error: input and output num radix blocks must be equal")

View File

@@ -24,14 +24,12 @@ Torus is_x_less_than_y_given_input_borrow(Torus last_x_block,
return output_sign_bit ^ overflow_flag;
}
template <typename Torus>
__host__ void scalar_compare_radix_blocks(CudaStreams streams,
CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI *lwe_array_in,
Torus *scalar_blocks,
int_comparison_buffer<Torus> *mem_ptr,
void *const *bsks, Torus *const *ksks,
uint32_t num_radix_blocks) {
template <typename Torus, typename KSTorus>
__host__ void scalar_compare_radix_blocks(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI *lwe_array_in, Torus *scalar_blocks,
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
KSTorus *const *ksks, uint32_t num_radix_blocks) {
if (num_radix_blocks == 0)
return;
@@ -84,13 +82,14 @@ __host__ void scalar_compare_radix_blocks(CudaStreams streams,
carry_modulus);
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void integer_radix_unsigned_scalar_difference_check(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_in, Torus const *scalar_blocks,
Torus const *h_scalar_blocks, int_comparison_buffer<Torus> *mem_ptr,
std::function<Torus(Torus)> sign_handler_f, void *const *bsks,
Torus *const *ksks, uint32_t num_radix_blocks, uint32_t num_scalar_blocks) {
KSTorus *const *ksks, uint32_t num_radix_blocks,
uint32_t num_scalar_blocks) {
if (lwe_array_out->lwe_dimension != lwe_array_in->lwe_dimension)
PANIC("Cuda error: input lwe dimensions must be the same")
if (lwe_array_in->num_radix_blocks < num_radix_blocks)
@@ -322,13 +321,14 @@ __host__ void integer_radix_unsigned_scalar_difference_check(
}
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void integer_radix_signed_scalar_difference_check(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_in, Torus const *scalar_blocks,
Torus const *h_scalar_blocks, int_comparison_buffer<Torus> *mem_ptr,
std::function<Torus(Torus)> sign_handler_f, void *const *bsks,
Torus *const *ksks, uint32_t num_radix_blocks, uint32_t num_scalar_blocks) {
KSTorus *const *ksks, uint32_t num_radix_blocks,
uint32_t num_scalar_blocks) {
if (lwe_array_out->lwe_dimension != lwe_array_in->lwe_dimension)
PANIC("Cuda error: input lwe dimensions must be the same")
@@ -640,13 +640,14 @@ __host__ void integer_radix_signed_scalar_difference_check(
}
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_scalar_difference_check(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_in, Torus const *scalar_blocks,
Torus const *h_scalar_blocks, int_comparison_buffer<Torus> *mem_ptr,
std::function<Torus(Torus)> sign_handler_f, void *const *bsks,
Torus *const *ksks, uint32_t num_radix_blocks, uint32_t num_scalar_blocks) {
KSTorus *const *ksks, uint32_t num_radix_blocks,
uint32_t num_scalar_blocks) {
if (lwe_array_out->lwe_dimension != lwe_array_in->lwe_dimension)
PANIC("Cuda error: input lwe dimensions must be the same")
@@ -668,13 +669,13 @@ __host__ void host_scalar_difference_check(
}
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void
host_scalar_maxmin(CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_in,
Torus const *scalar_blocks, Torus const *h_scalar_blocks,
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks, uint32_t num_radix_blocks,
KSTorus *const *ksks, uint32_t num_radix_blocks,
uint32_t num_scalar_blocks) {
if (lwe_array_out->lwe_dimension != lwe_array_in->lwe_dimension)
@@ -712,12 +713,13 @@ host_scalar_maxmin(CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
ksks);
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_scalar_equality_check(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_in, Torus const *scalar_blocks,
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks, uint32_t num_radix_blocks, uint32_t num_scalar_blocks) {
KSTorus *const *ksks, uint32_t num_radix_blocks,
uint32_t num_scalar_blocks) {
if (lwe_array_out->lwe_dimension != lwe_array_in->lwe_dimension)
PANIC("Cuda error: input and output lwe dimensions must be the same")

View File

@@ -24,11 +24,11 @@ __host__ uint64_t scratch_integer_unsigned_scalar_div_radix(
return size_tracker;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_integer_unsigned_scalar_div_radix(
CudaStreams streams, CudaRadixCiphertextFFI *numerator_ct,
int_unsigned_scalar_div_mem<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks, const CudaScalarDivisorFFI *scalar_divisor_ffi) {
KSTorus *const *ksks, const CudaScalarDivisorFFI *scalar_divisor_ffi) {
if (scalar_divisor_ffi->is_abs_divisor_one) {
return;
@@ -118,11 +118,11 @@ __host__ uint64_t scratch_integer_signed_scalar_div_radix(
return size_tracker;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_integer_signed_scalar_div_radix(
CudaStreams streams, CudaRadixCiphertextFFI *numerator_ct,
int_signed_scalar_div_mem<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks, const CudaScalarDivisorFFI *scalar_divisor_ffi,
KSTorus *const *ksks, const CudaScalarDivisorFFI *scalar_divisor_ffi,
uint32_t numerator_bits) {
if (scalar_divisor_ffi->is_abs_divisor_one) {
@@ -247,12 +247,12 @@ __host__ uint64_t scratch_integer_unsigned_scalar_div_rem_radix(
return size_tracker;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_integer_unsigned_scalar_div_rem_radix(
CudaStreams streams, CudaRadixCiphertextFFI *quotient_ct,
CudaRadixCiphertextFFI *remainder_ct,
int_unsigned_scalar_div_rem_buffer<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks, const CudaScalarDivisorFFI *scalar_divisor_ffi,
KSTorus *const *ksks, const CudaScalarDivisorFFI *scalar_divisor_ffi,
uint64_t const *divisor_has_at_least_one_set,
uint64_t const *decomposed_divisor, uint32_t const num_scalars_divisor,
Torus const *clear_blocks, Torus const *h_clear_blocks,
@@ -314,12 +314,12 @@ __host__ uint64_t scratch_integer_signed_scalar_div_rem_radix(
return size_tracker;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_integer_signed_scalar_div_rem_radix(
CudaStreams streams, CudaRadixCiphertextFFI *quotient_ct,
CudaRadixCiphertextFFI *remainder_ct,
int_signed_scalar_div_rem_buffer<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks, const CudaScalarDivisorFFI *scalar_divisor_ffi,
KSTorus *const *ksks, const CudaScalarDivisorFFI *scalar_divisor_ffi,
uint64_t const *divisor_has_at_least_one_set,
uint64_t const *decomposed_divisor, uint32_t const num_scalars_divisor,
uint32_t numerator_bits) {

View File

@@ -44,11 +44,11 @@ __host__ uint64_t scratch_cuda_scalar_mul(CudaStreams streams,
return size_tracker;
}
template <typename T>
template <typename T, typename KSTorus>
__host__ void host_integer_scalar_mul_radix(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array,
T const *decomposed_scalar, T const *has_at_least_one_set,
int_scalar_mul_buffer<T> *mem, void *const *bsks, T *const *ksks,
int_scalar_mul_buffer<T> *mem, void *const *bsks, KSTorus *const *ksks,
uint32_t message_modulus, uint32_t num_scalars) {
auto num_radix_blocks = lwe_array->num_radix_blocks;
@@ -167,11 +167,11 @@ __host__ void host_integer_small_scalar_mul_radix(
}
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void
host_scalar_mul_high(CudaStreams streams, CudaRadixCiphertextFFI *ct,
int_scalar_mul_high_buffer<Torus> *mem_ptr,
Torus *const *ksks, void *const *bsks,
KSTorus *const *ksks, void *const *bsks,
const CudaScalarDivisorFFI *scalar_divisor_ffi) {
if (scalar_divisor_ffi->is_chosen_multiplier_zero) {
@@ -207,10 +207,10 @@ host_scalar_mul_high(CudaStreams streams, CudaRadixCiphertextFFI *ct,
host_trim_radix_blocks_lsb<Torus>(ct, tmp_ffi, streams);
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_signed_scalar_mul_high(
CudaStreams streams, CudaRadixCiphertextFFI *ct,
int_signed_scalar_mul_high_buffer<Torus> *mem_ptr, Torus *const *ksks,
int_signed_scalar_mul_high_buffer<Torus> *mem_ptr, KSTorus *const *ksks,
const CudaScalarDivisorFFI *scalar_divisor_ffi, void *const *bsks) {
if (scalar_divisor_ffi->is_chosen_multiplier_zero) {

View File

@@ -21,12 +21,12 @@ __host__ uint64_t scratch_cuda_scalar_rotate(
return size_tracker;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void
host_scalar_rotate_inplace(CudaStreams streams,
CudaRadixCiphertextFFI *lwe_array, uint32_t n,
int_logical_scalar_shift_buffer<Torus> *mem,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
auto num_blocks = lwe_array->num_radix_blocks;
auto params = mem->params;

View File

@@ -22,11 +22,11 @@ __host__ uint64_t scratch_cuda_logical_scalar_shift(
return size_tracker;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_logical_scalar_shift_inplace(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array, uint32_t shift,
int_logical_scalar_shift_buffer<Torus> *mem, void *const *bsks,
Torus *const *ksks, uint32_t num_blocks) {
KSTorus *const *ksks, uint32_t num_blocks) {
if (lwe_array->num_radix_blocks < num_blocks)
PANIC("Cuda error: input does not have enough blocks")
@@ -126,11 +126,11 @@ __host__ uint64_t scratch_cuda_arithmetic_scalar_shift(
return size_tracker;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_arithmetic_scalar_shift_inplace(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array, uint32_t shift,
int_arithmetic_scalar_shift_buffer<Torus> *mem, void *const *bsks,
Torus *const *ksks) {
KSTorus *const *ksks) {
auto num_blocks = lwe_array->num_radix_blocks;
auto params = mem->params;

View File

@@ -22,13 +22,13 @@ __host__ uint64_t scratch_cuda_shift_and_rotate(
return size_tracker;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void
host_shift_and_rotate_inplace(CudaStreams streams,
CudaRadixCiphertextFFI *lwe_array,
CudaRadixCiphertextFFI const *lwe_shift,
int_shift_and_rotate_buffer<Torus> *mem,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
cuda_set_device(streams.gpu_index(0));
if (lwe_array->num_radix_blocks != lwe_shift->num_radix_blocks)

View File

@@ -28,12 +28,12 @@ uint64_t scratch_cuda_sub_and_propagate_single_carry(
return size_tracker;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
void host_sub_and_propagate_single_carry(
CudaStreams streams, CudaRadixCiphertextFFI *lhs_array,
const CudaRadixCiphertextFFI *rhs_array, CudaRadixCiphertextFFI *carry_out,
const CudaRadixCiphertextFFI *input_carries,
int_sub_and_propagate<Torus> *mem, void *const *bsks, Torus *const *ksks,
int_sub_and_propagate<Torus> *mem, void *const *bsks, KSTorus *const *ksks,
uint32_t requested_flag, uint32_t uses_carry) {
host_negation<Torus>(streams, mem->neg_rhs_array, rhs_array,

View File

@@ -203,11 +203,11 @@ __global__ void tgemm(uint M, uint N, uint K, const Torus *A, const Torus *B,
// This code is adapted from the 1d block-tiling
// kernel at https://github.com/siboehm/SGEMM_CUDA,
// generalized to any matrix dimension
template <typename Torus, int BLOCK_SIZE, int THREADS>
template <typename Torus, typename IndicesType, int BLOCK_SIZE, int THREADS>
__global__ void tgemm_with_indices(uint M, uint N, uint K, const Torus *A,
const Torus *B, uint stride_B, Torus *C,
uint stride_C,
const Torus *__restrict__ C_indices) {
const IndicesType *__restrict__ C_indices) {
const int BM = BLOCK_SIZE;
const int BN = BLOCK_SIZE;

View File

@@ -137,7 +137,7 @@ TEST_P(KeyswitchMultiGPUTestPrimitives_u64, keyswitch) {
d_lwe_ct_out_array + (ptrdiff_t)(output_lwe_start_index);
// Execute keyswitch
cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
cuda_keyswitch_gemm_lwe_ciphertext_vector_64_64(
streams[gpu_i], gpu_i, d_lwe_ct_out, lwe_output_indexes,
d_lwe_ct_in_slice, lwe_input_indexes, d_ksk, input_lwe_dimension,
output_lwe_dimension, ksk_base_log, ksk_level, num_inputs,
@@ -248,7 +248,7 @@ TEST_P(KeyswitchTestPrimitives_u64, keyswitch) {
(ptrdiff_t)((r * SAMPLES * number_of_inputs + s * number_of_inputs) *
(input_lwe_dimension + 1));
// Execute keyswitch
cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
cuda_keyswitch_gemm_lwe_ciphertext_vector_64_64(
stream, gpu_index, (void *)d_lwe_ct_out_array,
(void *)lwe_output_indexes, (void *)d_lwe_ct_in,
(void *)lwe_input_indexes, (void *)d_ksk, input_lwe_dimension,

View File

@@ -2544,7 +2544,7 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub fn cuda_keyswitch_lwe_ciphertext_vector_32(
pub fn cuda_keyswitch_lwe_ciphertext_vector_64_64(
stream: *mut ffi::c_void,
gpu_index: u32,
lwe_array_out: *mut ffi::c_void,
@@ -2560,25 +2560,7 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub fn cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
stream: *mut ffi::c_void,
gpu_index: u32,
lwe_array_out: *mut ffi::c_void,
lwe_output_indexes: *const ffi::c_void,
lwe_array_in: *const ffi::c_void,
lwe_input_indexes: *const ffi::c_void,
ksk: *const ffi::c_void,
lwe_dimension_in: u32,
lwe_dimension_out: u32,
base_log: u32,
level_count: u32,
num_samples: u32,
ks_tmp_buffer: *const ffi::c_void,
uses_trivial_indexes: bool,
);
}
unsafe extern "C" {
pub fn cuda_keyswitch_lwe_ciphertext_vector_64(
pub fn cuda_keyswitch_lwe_ciphertext_vector_64_32(
stream: *mut ffi::c_void,
gpu_index: u32,
lwe_array_out: *mut ffi::c_void,
@@ -2605,6 +2587,42 @@ unsafe extern "C" {
allocate_gpu_memory: bool,
) -> u64;
}
unsafe extern "C" {
pub fn cuda_keyswitch_gemm_lwe_ciphertext_vector_64_64(
stream: *mut ffi::c_void,
gpu_index: u32,
lwe_array_out: *mut ffi::c_void,
lwe_output_indexes: *const ffi::c_void,
lwe_array_in: *const ffi::c_void,
lwe_input_indexes: *const ffi::c_void,
ksk: *const ffi::c_void,
lwe_dimension_in: u32,
lwe_dimension_out: u32,
base_log: u32,
level_count: u32,
num_samples: u32,
ks_tmp_buffer: *const ffi::c_void,
uses_trivial_indexes: bool,
);
}
unsafe extern "C" {
pub fn cuda_keyswitch_gemm_lwe_ciphertext_vector_64_32(
stream: *mut ffi::c_void,
gpu_index: u32,
lwe_array_out: *mut ffi::c_void,
lwe_output_indexes: *const ffi::c_void,
lwe_array_in: *const ffi::c_void,
lwe_input_indexes: *const ffi::c_void,
ksk: *const ffi::c_void,
lwe_dimension_in: u32,
lwe_dimension_out: u32,
base_log: u32,
level_count: u32,
num_samples: u32,
ks_tmp_buffer: *const ffi::c_void,
uses_trivial_indexes: bool,
);
}
unsafe extern "C" {
pub fn scratch_cuda_keyswitch_gemm_64(
stream: *mut ffi::c_void,
@@ -2676,6 +2694,16 @@ unsafe extern "C" {
gpu_memory_allocated: bool,
);
}
unsafe extern "C" {
pub fn cuda_closest_representable_64(
stream: *mut ffi::c_void,
gpu_index: u32,
input: *const ffi::c_void,
output: *mut ffi::c_void,
base_log: u32,
level_count: u32,
);
}
unsafe extern "C" {
pub fn cuda_negate_lwe_ciphertext_vector_32(
stream: *mut ffi::c_void,

View File

@@ -339,7 +339,10 @@ mod cuda {
use tfhe::core_crypto::prelude::*;
fn cuda_keyswitch<Scalar: UnsignedTorus + CastInto<usize> + CastFrom<u64> + Serialize>(
fn cuda_keyswitch_classical_and_gemm<
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<u64> + Serialize,
KeyswitchScalar: UnsignedTorus + CastFrom<Scalar>,
>(
criterion: &mut Criterion,
parameters: &[(String, CryptoParametersRecord<Scalar>)],
) {
@@ -361,27 +364,57 @@ mod cuda {
let ks_decomp_base_log = params.ks_base_log.unwrap();
let ks_decomp_level_count = params.ks_level.unwrap();
let lwe_sk = allocate_and_generate_new_binary_lwe_secret_key(
lwe_dimension,
&mut secret_generator,
);
let lwe_noise_distribution_ksk = match params.lwe_noise_distribution.unwrap() {
DynamicDistribution::Gaussian(gaussian_lwe_noise_distribution) => {
DynamicDistribution::<KeyswitchScalar>::new_gaussian(
gaussian_lwe_noise_distribution.standard_dev(),
)
}
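// Note: the TUniform bound is expressed against the native torus width, so
// moving the KSK from a 64-bit to a 32-bit modulus rescales the bound by
// 2^-32, hence the -32 adjustment in the 32-bit case below.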
DynamicDistribution::TUniform(uniform_lwe_noise_distribution) => {
DynamicDistribution::<KeyswitchScalar>::new_t_uniform(
match KeyswitchScalar::BITS {
32 => uniform_lwe_noise_distribution.bound_log2() - 32,
64 => uniform_lwe_noise_distribution.bound_log2(),
_ => panic!("Unsupported Keyswitch scalar input dtype"),
},
)
}
};
let glwe_sk = allocate_and_generate_new_binary_glwe_secret_key(
glwe_dimension,
polynomial_size,
&mut secret_generator,
);
let lwe_sk: LweSecretKeyOwned<KeyswitchScalar> =
allocate_and_generate_new_binary_lwe_secret_key(
lwe_dimension,
&mut secret_generator,
);
let glwe_sk: GlweSecretKeyOwned<KeyswitchScalar> =
allocate_and_generate_new_binary_glwe_secret_key(
glwe_dimension,
polynomial_size,
&mut secret_generator,
);
let big_lwe_sk = glwe_sk.into_lwe_secret_key();
let ksk_big_to_small = allocate_and_generate_new_lwe_keyswitch_key(
&big_lwe_sk,
&lwe_sk,
ks_decomp_base_log,
ks_decomp_level_count,
params.lwe_noise_distribution.unwrap(),
lwe_noise_distribution_ksk,
CiphertextModulus::new_native(),
&mut encryption_generator,
);
let glwe_sk_64: GlweSecretKeyOwned<Scalar> =
allocate_and_generate_new_binary_glwe_secret_key(
glwe_dimension,
polynomial_size,
&mut secret_generator,
);
let big_lwe_sk_64 = glwe_sk_64.into_lwe_secret_key();
let ciphertext_modulus_out = CiphertextModulus::<KeyswitchScalar>::new_native();
let cpu_keys: CpuKeys<_> = CpuKeysBuilder::new()
.keyswitch_key(ksk_big_to_small)
.build();
@@ -394,7 +427,7 @@ mod cuda {
let gpu_keys = CudaLocalKeys::from_cpu_keys(&cpu_keys, None, &streams);
let ct = allocate_and_encrypt_new_lwe_ciphertext(
&big_lwe_sk,
&big_lwe_sk_64,
Plaintext(Scalar::ONE),
params.lwe_noise_distribution.unwrap(),
CiphertextModulus::new_native(),
@@ -403,7 +436,7 @@ mod cuda {
let mut ct_gpu = CudaLweCiphertextList::from_lwe_ciphertext(&ct, &streams);
let output_ct = LweCiphertext::new(
Scalar::ZERO,
KeyswitchScalar::ZERO,
lwe_sk.lwe_dimension().to_lwe_size(),
CiphertextModulus::new_native(),
);
@@ -413,7 +446,10 @@ mod cuda {
let h_indexes = [Scalar::ZERO];
let cuda_indexes = CudaIndexes::new(&h_indexes, &streams, 0);
bench_id = format!("{bench_name}::{name}");
bench_id = format!(
"{bench_name}::latency::{:?}b::{name}",
KeyswitchScalar::BITS
);
{
bench_group.bench_function(&bench_id, |b| {
b.iter(|| {
@@ -457,7 +493,8 @@ mod cuda {
};
let gemm_str = if uses_gemm_ks { "gemm" } else { "classical" };
bench_id = format!(
"{bench_name}::throughput::{gemm_str}::{indices_str}_indices::{name}",
"{bench_name}::throughput::{:?}b::{gemm_str}::{indices_str}_indices::{name}",
KeyswitchScalar::BITS
);
let blocks: usize = 256;
@@ -483,7 +520,7 @@ mod cuda {
params.ciphertext_modulus.unwrap(),
);
encrypt_lwe_ciphertext_list(
&big_lwe_sk,
&big_lwe_sk_64,
&mut input_ct_list,
&plaintext_list,
params.lwe_noise_distribution.unwrap(),
@@ -504,10 +541,10 @@ mod cuda {
let output_cts = (0..gpu_count)
.map(|i| {
let output_ct_list = LweCiphertextList::new(
Scalar::ZERO,
KeyswitchScalar::ZERO,
lwe_sk.lwe_dimension().to_lwe_size(),
LweCiphertextCount(elements_per_stream),
params.ciphertext_modulus.unwrap(),
ciphertext_modulus_out,
);
CudaLweCiphertextList::from_lwe_ciphertext_list(
&output_ct_list,
@@ -584,7 +621,7 @@ mod cuda {
}
fn cuda_packing_keyswitch<
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<u64> + Serialize,
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<u64> + Serialize + CastInto<u32>,
>(
criterion: &mut Criterion,
parameters: &[(String, CryptoParametersRecord<Scalar>)],
@@ -791,9 +828,9 @@ mod cuda {
.zip(local_streams.par_iter())
.for_each(
|(
((i, input_lwe_list), output_glwe_list),
local_stream,
)| {
((i, input_lwe_list), output_glwe_list),
local_stream,
)| {
cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64(
gpu_keys_vec[i].pksk.as_ref().unwrap(),
input_lwe_list,
@@ -826,7 +863,8 @@ mod cuda {
let mut criterion: Criterion<_> = (Criterion::default().sample_size(15))
.measurement_time(std::time::Duration::from_secs(60))
.configure_from_args();
cuda_keyswitch(&mut criterion, &benchmark_parameters());
cuda_keyswitch_classical_and_gemm::<u64, u32>(&mut criterion, &benchmark_parameters());
cuda_keyswitch_classical_and_gemm::<u64, u64>(&mut criterion, &benchmark_parameters());
cuda_packing_keyswitch(&mut criterion, &benchmark_parameters());
}
@@ -834,7 +872,8 @@ mod cuda {
let mut criterion: Criterion<_> = (Criterion::default().sample_size(15))
.measurement_time(std::time::Duration::from_secs(60))
.configure_from_args();
cuda_keyswitch(&mut criterion, &benchmark_parameters());
cuda_keyswitch_classical_and_gemm::<u64, u32>(&mut criterion, &benchmark_parameters());
cuda_keyswitch_classical_and_gemm::<u64, u64>(&mut criterion, &benchmark_parameters());
}
pub fn cuda_multi_bit_ks_group() {
@@ -844,7 +883,8 @@ mod cuda {
.into_iter()
.map(|(string, params, _)| (string, params))
.collect_vec();
cuda_keyswitch(&mut criterion, &multi_bit_parameters);
cuda_keyswitch_classical_and_gemm::<u64, u32>(&mut criterion, &multi_bit_parameters);
cuda_keyswitch_classical_and_gemm::<u64, u64>(&mut criterion, &multi_bit_parameters);
cuda_packing_keyswitch(&mut criterion, &multi_bit_parameters);
}
@@ -855,7 +895,8 @@ mod cuda {
.into_iter()
.map(|(string, params, _)| (string, params))
.collect_vec();
cuda_keyswitch(&mut criterion, &multi_bit_parameters);
cuda_keyswitch_classical_and_gemm::<u64, u32>(&mut criterion, &multi_bit_parameters);
cuda_keyswitch_classical_and_gemm::<u64, u64>(&mut criterion, &multi_bit_parameters);
}
}

View File

@@ -44,7 +44,7 @@ mod decomposer;
mod iter;
mod term;
#[cfg(test)]
mod tests;
pub(crate) mod tests;
/// The level of a given term of a decomposition.
///

View File

@@ -12,7 +12,7 @@ use std::fmt::Debug;
pub const NB_TESTS: usize = 10_000_000;
fn valid_decomposers<T: UnsignedInteger>() -> Vec<SignedDecomposer<T>> {
pub(crate) fn valid_decomposers<T: UnsignedInteger>() -> Vec<SignedDecomposer<T>> {
let mut valid_decomposers = vec![];
for base_log in (1..T::BITS).map(DecompositionBaseLog) {
for level_count in (1..T::BITS).map(DecompositionLevelCount) {

View File

@@ -46,7 +46,7 @@ impl<T: UnsignedInteger> TUniform<T> {
/// representation of integers.
pub const fn try_new(bound_log2: u32) -> Result<Self, &'static str> {
if (bound_log2 + 2) as usize > T::BITS {
return Err("Cannot create TUnfirorm: \
return Err("Cannot create TUniform: \
bound_log2 + 2 is greater than the current type's bit width");
}

View File

@@ -2,11 +2,11 @@ use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
use crate::core_crypto::gpu::vec::CudaVec;
use crate::core_crypto::gpu::{
keyswitch_async, keyswitch_async_gemm, scratch_cuda_keyswitch_gemm_64, CudaStreams,
cleanup_cuda_keyswitch_gemm_64, cuda_closest_representable_64, keyswitch_async,
keyswitch_async_gemm, scratch_cuda_keyswitch_gemm_64, CudaStreams,
};
use crate::core_crypto::prelude::UnsignedInteger;
use std::cmp::min;
use tfhe_cuda_backend::bindings::cleanup_cuda_keyswitch_gemm_64;
use tfhe_cuda_backend::ffi;
/// # Safety
@@ -14,10 +14,10 @@ use tfhe_cuda_backend::ffi;
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must not
/// be dropped until stream is synchronized
#[allow(clippy::too_many_arguments)]
pub unsafe fn cuda_keyswitch_lwe_ciphertext_async<Scalar>(
lwe_keyswitch_key: &CudaLweKeyswitchKey<Scalar>,
pub unsafe fn cuda_keyswitch_lwe_ciphertext_async<Scalar, KSKScalar>(
lwe_keyswitch_key: &CudaLweKeyswitchKey<KSKScalar>,
input_lwe_ciphertext: &CudaLweCiphertextList<Scalar>,
output_lwe_ciphertext: &mut CudaLweCiphertextList<Scalar>,
output_lwe_ciphertext: &mut CudaLweCiphertextList<KSKScalar>,
input_indexes: &CudaVec<Scalar>,
output_indexes: &CudaVec<Scalar>,
uses_trivial_indices: bool,
@@ -25,6 +25,7 @@ pub unsafe fn cuda_keyswitch_lwe_ciphertext_async<Scalar>(
use_gemm_ks: bool,
) where
Scalar: UnsignedInteger,
KSKScalar: UnsignedInteger,
{
assert!(
lwe_keyswitch_key.input_key_lwe_size().to_lwe_dimension()
@@ -91,6 +92,7 @@ pub unsafe fn cuda_keyswitch_lwe_ciphertext_async<Scalar>(
);
if use_gemm_ks {
// Scratch allocation uses the input LWE dtype for the buffer size
cuda_scratch_keyswitch_lwe_ciphertext_async::<Scalar>(
streams,
std::ptr::addr_of_mut!(ks_tmp_buffer),
@@ -100,6 +102,7 @@ pub unsafe fn cuda_keyswitch_lwe_ciphertext_async<Scalar>(
true,
);
// GEMM KS can keyswitch from input LWE dtype Scalar to output LWE dtype KSKScalar
keyswitch_async_gemm(
streams,
&mut output_lwe_ciphertext.0.d_vec,
@@ -139,10 +142,10 @@ pub unsafe fn cuda_keyswitch_lwe_ciphertext_async<Scalar>(
}
#[allow(clippy::too_many_arguments)]
pub fn cuda_keyswitch_lwe_ciphertext<Scalar>(
lwe_keyswitch_key: &CudaLweKeyswitchKey<Scalar>,
pub fn cuda_keyswitch_lwe_ciphertext<Scalar, KSKScalar>(
lwe_keyswitch_key: &CudaLweKeyswitchKey<KSKScalar>,
input_lwe_ciphertext: &CudaLweCiphertextList<Scalar>,
output_lwe_ciphertext: &mut CudaLweCiphertextList<Scalar>,
output_lwe_ciphertext: &mut CudaLweCiphertextList<KSKScalar>,
input_indexes: &CudaVec<Scalar>,
output_indexes: &CudaVec<Scalar>,
uses_trivial_indices: bool,
@@ -150,6 +153,7 @@ pub fn cuda_keyswitch_lwe_ciphertext<Scalar>(
use_gemm_ks: bool,
) where
Scalar: UnsignedInteger,
KSKScalar: UnsignedInteger,
{
unsafe {
cuda_keyswitch_lwe_ciphertext_async(
@@ -207,3 +211,25 @@ pub unsafe fn cleanup_cuda_keyswitch_async<Scalar>(
allocate_gpu_memory,
);
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must not
/// be dropped until stream is synchronized
pub unsafe fn cuda_closest_representable<Scalar>(
streams: &CudaStreams,
input: &CudaVec<Scalar>,
output: &mut CudaVec<Scalar>,
base_log: u32,
level_count: u32,
) where
Scalar: UnsignedInteger,
{
cuda_closest_representable_64(
streams.ptr[0],
streams.gpu_indexes[0].get(),
input.as_c_ptr(0),
output.as_mut_c_ptr(0),
base_log,
level_count,
);
}
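// A minimal CPU-side sketch of the rounding that `cuda_closest_representable`
// is expected to match, assuming the usual round-half-up convention of the
// signed decomposition (the helper name below is illustrative, not part of
// the backend API):
fn closest_representable_sketch(value: u64, base_log: u32, level_count: u32) -> u64 {
    // Bits kept by a decomposition with `level_count` levels of `base_log` bits.
    let rep_bits = base_log * level_count;
    debug_assert!(rep_bits > 0 && rep_bits < 64);
    let non_rep_bits = 64 - rep_bits;
    // Round to the nearest representable value by adding half a step,
    // then clearing the discarded low bits.
    let half_step = 1u64 << (non_rep_bits - 1);
    value.wrapping_add(half_step) & !((1u64 << non_rep_bits) - 1)
}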

View File

@@ -1,8 +1,11 @@
use super::*;
use crate::core_crypto::commons::test_tools::any_uint;
use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
use crate::core_crypto::gpu::vec::{CudaVec, GpuIndex};
use crate::core_crypto::gpu::{cuda_keyswitch_lwe_ciphertext, CudaStreams};
use crate::core_crypto::gpu::{
cuda_closest_representable, cuda_keyswitch_lwe_ciphertext, CudaStreams,
};
use crate::core_crypto::prelude::misc::check_encrypted_content_respects_mod;
use itertools::Itertools;
use rand::seq::SliceRandom;
@@ -61,6 +64,11 @@ fn lwe_encrypt_ks_decrypt_custom_mod_mb<Scalar: UnsignedTorus + CastFrom<usize>>
}
#[allow(clippy::too_many_arguments)]
// Used for both the Multi-Bit and Classic PBS settings.
// Tests GEMM and Classic KS:
// - tests that keyswitched LWE is decrypted correctly
// - tests that GEMM and Classic KS are bit-wise equivalent
// - tests that only a subset of LWEs can be keyswitched
fn base_lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize>>(
lwe_dimension: LweDimension,
lwe_noise_distribution: DynamicDistribution<Scalar>,
@@ -166,6 +174,7 @@ fn base_lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize
} else {
num_blocks
};
let lwe_indexes_usize = (0..num_blocks).collect_vec();
let mut lwe_indexes = lwe_indexes_usize.clone();
@@ -227,20 +236,35 @@ fn base_lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize
assert_eq!(output_ct_list_gpu.lwe_ciphertext_count().0, num_blocks);
// The output has `n_blocks` LWEs but only some are actually set - those
// that correspond to output indices. We loop over all LWEs in the output buffer
let output_ct_list_cpu = output_ct_list_gpu_gemm.to_lwe_ciphertext_list(&stream);
let output_ct_list_cpu_gemm = output_ct_list_gpu_gemm.to_lwe_ciphertext_list(&stream);
output_ct_list_gpu
.to_lwe_ciphertext_list(&stream)
.iter()
.zip(0..num_blocks)
.for_each(|(lwe_ct_out, i)| {
let lwe_ct_out_gemm = output_ct_list_cpu_gemm.get(i);
let tmp_classical = lwe_ct_out.into_container();
let tmp_gemm = lwe_ct_out_gemm.into_container();
// Compare bitwise the output of classical KS and GEMM KS
for (v1, v2) in tmp_classical.iter().zip(tmp_gemm.iter()) {
assert_eq!(*v1, *v2);
}
assert!(check_encrypted_content_respects_mod(
&lwe_ct_out,
ciphertext_modulus
));
let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &lwe_ct_out);
// Check GEMM vs Classical bitwise equivalent
let tmp_gemm = lwe_ct_out_gemm.into_container();
for (v1, v2) in tmp_classical.iter().zip(tmp_gemm.iter()) {
assert_eq!(v1, v2);
}
let lwe_ct_out_gemm = output_ct_list_cpu.get(i);
// Check GEMM & Classical KS decrypt to reference value
let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &lwe_ct_out);
let decrypted_gemm = decrypt_lwe_ciphertext(&lwe_sk, &lwe_ct_out_gemm);
let decoded = round_decode(decrypted.0, delta) % msg_modulus;
@@ -253,5 +277,306 @@ fn base_lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize
}
}
#[allow(clippy::too_many_arguments)]
fn lwe_encrypt_ks_decrypt_ks32_common<
Scalar: UnsignedTorus + CastFrom<usize> + CastInto<KSKScalar> + CastFrom<KSKScalar>,
KSKScalar: UnsignedTorus + CastFrom<Scalar>,
>(
lwe_dimension: LweDimension,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
lwe_noise_distribution: DynamicDistribution<Scalar>,
ks_decomp_base_log: DecompositionBaseLog,
ks_decomp_level_count: DecompositionLevelCount,
message_modulus_log: MessageModulusLog,
ciphertext_modulus: CiphertextModulus<Scalar>,
) {
let input_encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
let output_ciphertext_modulus = CiphertextModulus::<KSKScalar>::new_native();
let output_encoding_with_padding = get_encoding_with_padding(output_ciphertext_modulus);
let lwe_noise_distribution_u32 = match lwe_noise_distribution {
DynamicDistribution::Gaussian(gaussian_lwe_noise_distribution) => {
DynamicDistribution::<KSKScalar>::new_gaussian(
gaussian_lwe_noise_distribution.standard_dev(),
)
}
DynamicDistribution::TUniform(uniform_lwe_noise_distribution) => {
DynamicDistribution::<KSKScalar>::new_t_uniform(
uniform_lwe_noise_distribution.bound_log2(),
)
}
};
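// Messages are encoded in the top bits of each torus: input_delta scales to the
// 64-bit input torus, output_delta to the (possibly narrower) KSK output torus.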
let input_msg_modulus = Scalar::ONE << message_modulus_log.0;
let output_msg_modulus = KSKScalar::ONE << message_modulus_log.0;
let input_delta = input_encoding_with_padding / input_msg_modulus;
let output_delta = output_encoding_with_padding / output_msg_modulus;
let stream = CudaStreams::new_single_gpu(GpuIndex::new(0));
let mut rsc = TestResources::new();
const NB_TESTS: usize = 10;
let mut msg = input_msg_modulus;
let lwe_sk = allocate_and_generate_new_binary_lwe_secret_key(
lwe_dimension,
&mut rsc.secret_random_generator,
);
let glwe_sk = allocate_and_generate_new_binary_glwe_secret_key::<Scalar, _>(
glwe_dimension,
polynomial_size,
&mut rsc.secret_random_generator,
);
let big_lwe_sk_u32 = LweSecretKey::from_container(
glwe_sk
.as_ref()
.iter()
.copied()
.map(|x| x.cast_into())
.collect::<Vec<KSKScalar>>(),
);
let big_lwe_sk = glwe_sk.into_lwe_secret_key();
let ksk_big_to_small = allocate_and_generate_new_lwe_keyswitch_key(
&big_lwe_sk_u32,
&lwe_sk,
ks_decomp_base_log,
ks_decomp_level_count,
lwe_noise_distribution_u32,
output_ciphertext_modulus,
&mut rsc.encryption_random_generator,
);
assert!(check_encrypted_content_respects_mod(
&ksk_big_to_small,
output_ciphertext_modulus
));
let d_ksk_big_to_small =
CudaLweKeyswitchKey::from_lwe_keyswitch_key(&ksk_big_to_small, &stream);
while msg != Scalar::ZERO {
msg = msg.wrapping_sub(Scalar::ONE);
for _ in 0..NB_TESTS {
let plaintext = Plaintext(msg * input_delta);
let ct = allocate_and_encrypt_new_lwe_ciphertext(
&big_lwe_sk, //64b
plaintext,
lwe_noise_distribution,
ciphertext_modulus,
&mut rsc.encryption_random_generator,
);
assert!(check_encrypted_content_respects_mod(
&ct,
ciphertext_modulus
));
let mut output_ct_ref = LweCiphertext::new(
KSKScalar::ZERO,
lwe_sk.lwe_dimension().to_lwe_size(),
output_ciphertext_modulus,
);
keyswitch_lwe_ciphertext_with_scalar_change(&ksk_big_to_small, &ct, &mut output_ct_ref);
// KSK: 32b, input ct: 64b, output ct: 32b (for the u64 -> u32 instantiation)
let decrypted_cpu = decrypt_lwe_ciphertext(&lwe_sk, &output_ct_ref);
let decoded_cpu: KSKScalar =
round_decode(decrypted_cpu.0, output_delta) % output_msg_modulus;
assert_eq!(msg, decoded_cpu.cast_into());
let d_ct = CudaLweCiphertextList::from_lwe_ciphertext(&ct, &stream);
let mut d_output_ct = CudaLweCiphertextList::new(
ksk_big_to_small.output_key_lwe_dimension(),
LweCiphertextCount(1),
output_ciphertext_modulus,
&stream,
);
let mut d_output_ct_gemm = CudaLweCiphertextList::new(
ksk_big_to_small.output_key_lwe_dimension(),
LweCiphertextCount(1),
output_ciphertext_modulus,
&stream,
);
let num_blocks = d_ct.0.lwe_ciphertext_count.0;
let lwe_indexes_usize = (0..num_blocks).collect_vec();
let lwe_indexes = lwe_indexes_usize
.iter()
.map(|&x| <usize as CastInto<Scalar>>::cast_into(x))
.collect_vec();
let mut d_input_indexes =
unsafe { CudaVec::<Scalar>::new_async(num_blocks, &stream, 0) };
let mut d_output_indexes =
unsafe { CudaVec::<Scalar>::new_async(num_blocks, &stream, 0) };
unsafe { d_input_indexes.copy_from_cpu_async(&lwe_indexes, &stream, 0) };
unsafe { d_output_indexes.copy_from_cpu_async(&lwe_indexes, &stream, 0) };
cuda_keyswitch_lwe_ciphertext(
&d_ksk_big_to_small,
&d_ct,
&mut d_output_ct,
&d_input_indexes,
&d_output_indexes,
true,
&stream,
false,
);
cuda_keyswitch_lwe_ciphertext(
&d_ksk_big_to_small,
&d_ct,
&mut d_output_ct_gemm,
&d_input_indexes,
&d_output_indexes,
true,
&stream,
true,
);
let output_ct = d_output_ct.into_lwe_ciphertext(&stream);
let tmp = output_ct.clone().into_container();
let tmp_cpu = output_ct_ref.clone().into_container();
for (v1, v2) in tmp.iter().zip(tmp_cpu.iter()) {
assert_eq!(v1, v2);
}
let output_ct_gemm = d_output_ct_gemm.into_lwe_ciphertext(&stream);
let tmp = output_ct_gemm.clone().into_container();
let tmp_cpu = output_ct_ref.clone().into_container();
for (v1, v2) in tmp.iter().zip(tmp_cpu.iter()) {
assert_eq!(v1, v2);
}
assert!(check_encrypted_content_respects_mod(
&output_ct,
output_ciphertext_modulus
));
let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &output_ct);
let decoded = round_decode(decrypted.0, output_delta) % output_msg_modulus;
assert_eq!(msg, decoded.cast_into());
}
}
}
fn lwe_encrypt_ks_decrypt_custom_mod_ks32<
Scalar: UnsignedTorus + CastFrom<usize> + CastInto<u32> + CastFrom<u32>,
>(
params: &MultiBitTestKS32Params<Scalar>,
) where
u32: CastFrom<Scalar>,
{
lwe_encrypt_ks_decrypt_ks32_common::<Scalar, u32>(
params.lwe_dimension,
params.glwe_dimension,
params.polynomial_size,
params.lwe_noise_distribution,
params.ks_base_log,
params.ks_level,
params.message_modulus_log,
params.ciphertext_modulus,
);
}
fn test_util_closest_representable_on_gpu(
value: u64,
base_log: DecompositionBaseLog,
level_count: DecompositionLevelCount,
) -> u64 {
let stream = CudaStreams::new_single_gpu(GpuIndex::new(0));
let h_input: Vec<u64> = vec![value];
let mut d_input = unsafe { CudaVec::<u64>::new_async(1, &stream, 0) };
unsafe { d_input.copy_from_cpu_async(&h_input, &stream, 0) };
let mut d_output = unsafe { CudaVec::<u64>::new_async(1, &stream, 0) };
unsafe {
cuda_closest_representable(
&stream,
&d_input,
&mut d_output,
base_log.0 as u32,
level_count.0 as u32,
);
}
let mut h_output: Vec<u64> = vec![0];
unsafe {
d_output.copy_to_cpu_async(&mut h_output, &stream, 0);
}
stream.synchronize();
*h_output.first().unwrap()
}
#[test]
fn test_closest_representable_gpu() {
let base_log = DecompositionBaseLog(17);
let level_count = DecompositionLevelCount(3);
let decomposer = SignedDecomposer::new(base_log, level_count);
// This value triggers a negative state at the start of the decomposition. Invalid code
// using a logical shift will wrongly compute an intermediate value by not keeping the
// sign of the state on the last level: if base_log * (level_count + 1) > Scalar::BITS,
// the logical shift shifts in 0s instead of the 1s needed to keep the sign information
let val: u64 = 0x8000_00e3_55b0_c827;
let rounded = decomposer.closest_representable(val);
let recomp = decomposer.recompose(decomposer.decompose(val)).unwrap();
let rounded_gpu = test_util_closest_representable_on_gpu(val, base_log, level_count);
assert_eq!(rounded, recomp);
assert_eq!(rounded_gpu, rounded);
}
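// A tiny sketch of the pitfall described above: on a negative decomposition
// state, an arithmetic shift keeps the sign bits while a logical shift fills
// with zeros, which corrupts the value recomposed on the last level.
#[test]
fn shift_sign_sketch() {
    let state: i64 = -1;
    let arithmetic = (state >> 17) as u64; // sign bits shifted in: all ones
    let logical = (state as u64) >> 17; // zeros shifted in: sign is lost
    assert_ne!(arithmetic, logical);
}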
#[test]
fn test_round_to_closest_representable_gpu() {
let runs_per_decomposer = 100;
let valid_decomposers =
crate::core_crypto::commons::math::decomposition::tests::valid_decomposers::<u64>();
for decomposer in valid_decomposers {
// Checks that the closest representable computed on GPU is the same as on CPU
for _ in 0..runs_per_decomposer {
let val = any_uint::<u64>();
let rounded = test_util_closest_representable_on_gpu(
val,
decomposer.base_log(),
decomposer.level_count(),
);
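// The rounding step of the decomposition is 2^(64 - base_log * level_count);
// epsilon below is a quarter of that step, so `rounded` +/- epsilon stays
// inside the interval that rounds back to `rounded`.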
let epsilon =
(1u64 << (64 - (decomposer.base_log * decomposer.level_count) - 1)) / 2u64;
// Adding/removing an epsilon should not change the closest representable
assert_eq!(
rounded,
decomposer.closest_representable(rounded.wrapping_add(epsilon))
);
assert_eq!(
rounded,
decomposer.closest_representable(rounded.wrapping_sub(epsilon))
);
}
}
}
create_gpu_parameterized_test!(lwe_encrypt_ks_decrypt_custom_mod);
create_gpu_multi_bit_parameterized_test!(lwe_encrypt_ks_decrypt_custom_mod_mb);
create_gpu_multi_bit_ks32_parameterized_test!(lwe_encrypt_ks_decrypt_custom_mod_ks32);

View File

@@ -13,6 +13,7 @@ mod lwe_programmable_bootstrapping;
mod lwe_programmable_bootstrapping_128;
mod modulus_switch;
mod noise_distribution;
mod params;
pub struct CudaPackingKeySwitchKeys<Scalar: UnsignedInteger> {
pub lwe_sk: LweSecretKey<Vec<Scalar>>,
@@ -20,6 +21,24 @@ pub struct CudaPackingKeySwitchKeys<Scalar: UnsignedInteger> {
pub pksk: CudaLwePackingKeyswitchKey<Scalar>,
}
pub const MULTI_BIT_2_2_2_KS32_PARAMS: MultiBitTestKS32Params<u64> = MultiBitTestKS32Params {
lwe_dimension: LweDimension(920),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(2048),
lwe_noise_distribution: DynamicDistribution::new_t_uniform(13),
glwe_noise_distribution: DynamicDistribution::new_t_uniform(17),
pbs_base_log: DecompositionBaseLog(22),
pbs_level: DecompositionLevelCount(1),
ks_base_log: DecompositionBaseLog(3),
ks_level: DecompositionLevelCount(5),
message_modulus_log: MessageModulusLog(2),
log2_p_fail: -134.345,
ciphertext_modulus: CiphertextModulus::new_native(),
encryption_key_choice: EncryptionKeyChoice::Big,
grouping_factor: LweBskGroupingFactor(4),
deterministic_execution: false,
};
// Macro to generate tests for all parameter sets
macro_rules! create_gpu_parameterized_test{
($name:ident { $($param:ident),* }) => {
@@ -60,6 +79,27 @@ macro_rules! create_gpu_multi_bit_parameterized_test{
});
};
}
macro_rules! create_gpu_multi_bit_ks32_parameterized_test{
($name:ident { $($param:ident),* }) => {
::paste::paste! {
$(
#[test]
fn [<test_gpu_ $name _ $param:lower>]() {
$name(&$param)
}
)*
}
};
($name:ident)=> {
create_gpu_multi_bit_ks32_parameterized_test!($name
{
MULTI_BIT_2_2_2_KS32_PARAMS
});
};
}
use crate::core_crypto::gpu::algorithms::test::params::MultiBitTestKS32Params;
use crate::core_crypto::gpu::lwe_packing_keyswitch_key::CudaLwePackingKeyswitchKey;
use {create_gpu_multi_bit_parameterized_test, create_gpu_parameterized_test};
use {
create_gpu_multi_bit_ks32_parameterized_test, create_gpu_multi_bit_parameterized_test,
create_gpu_parameterized_test,
};

View File

@@ -0,0 +1,26 @@
use crate::core_crypto::commons::math::random::{Deserialize, Serialize};
use crate::core_crypto::prelude::{
CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, DynamicDistribution,
GlweDimension, LweBskGroupingFactor, LweDimension, MessageModulusLog, PolynomialSize,
UnsignedInteger,
};
use crate::shortint::EncryptionKeyChoice;
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct MultiBitTestKS32Params<Scalar: UnsignedInteger> {
pub lwe_dimension: LweDimension,
pub glwe_dimension: GlweDimension,
pub polynomial_size: PolynomialSize,
pub lwe_noise_distribution: DynamicDistribution<Scalar>,
pub glwe_noise_distribution: DynamicDistribution<Scalar>,
pub pbs_base_log: DecompositionBaseLog,
pub pbs_level: DecompositionLevelCount,
pub ks_base_log: DecompositionBaseLog,
pub ks_level: DecompositionLevelCount,
pub message_modulus_log: MessageModulusLog,
pub log2_p_fail: f64,
pub ciphertext_modulus: CiphertextModulus<Scalar>,
pub encryption_key_choice: EncryptionKeyChoice,
pub grouping_factor: LweBskGroupingFactor,
pub deterministic_execution: bool,
}

View File

@@ -485,37 +485,58 @@ pub fn get_programmable_bootstrap_multi_bit_size_on_gpu(
/// [CudaStreams::synchronize] __must__ be called as soon as synchronization is
/// required
#[allow(clippy::too_many_arguments)]
pub unsafe fn keyswitch_async_gemm<T: UnsignedInteger>(
pub unsafe fn keyswitch_async_gemm<T: UnsignedInteger, KST: UnsignedInteger>(
streams: &CudaStreams,
lwe_array_out: &mut CudaVec<T>,
lwe_array_out: &mut CudaVec<KST>,
lwe_out_indexes: &CudaVec<T>,
lwe_array_in: &CudaVec<T>,
lwe_in_indexes: &CudaVec<T>,
input_lwe_dimension: LweDimension,
output_lwe_dimension: LweDimension,
keyswitch_key: &CudaVec<T>,
keyswitch_key: &CudaVec<KST>,
base_log: DecompositionBaseLog,
l_gadget: DecompositionLevelCount,
num_samples: u32,
ks_tmp_buffer: *const ffi::c_void,
uses_trivial_indices: bool,
) {
cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
streams.ptr[0],
streams.gpu_indexes[0].get(),
lwe_array_out.as_mut_c_ptr(0),
lwe_out_indexes.as_c_ptr(0),
lwe_array_in.as_c_ptr(0),
lwe_in_indexes.as_c_ptr(0),
keyswitch_key.as_c_ptr(0),
input_lwe_dimension.0 as u32,
output_lwe_dimension.0 as u32,
base_log.0 as u32,
l_gadget.0 as u32,
num_samples,
ks_tmp_buffer,
uses_trivial_indices,
);
if TypeId::of::<KST>() == TypeId::of::<u32>() {
cuda_keyswitch_gemm_lwe_ciphertext_vector_64_32(
streams.ptr[0],
streams.gpu_indexes[0].get(),
lwe_array_out.as_mut_c_ptr(0),
lwe_out_indexes.as_c_ptr(0),
lwe_array_in.as_c_ptr(0),
lwe_in_indexes.as_c_ptr(0),
keyswitch_key.as_c_ptr(0),
input_lwe_dimension.0 as u32,
output_lwe_dimension.0 as u32,
base_log.0 as u32,
l_gadget.0 as u32,
num_samples,
ks_tmp_buffer,
uses_trivial_indices,
);
} else if TypeId::of::<KST>() == TypeId::of::<u64>() {
cuda_keyswitch_gemm_lwe_ciphertext_vector_64_64(
streams.ptr[0],
streams.gpu_indexes[0].get(),
lwe_array_out.as_mut_c_ptr(0),
lwe_out_indexes.as_c_ptr(0),
lwe_array_in.as_c_ptr(0),
lwe_in_indexes.as_c_ptr(0),
keyswitch_key.as_c_ptr(0),
input_lwe_dimension.0 as u32,
output_lwe_dimension.0 as u32,
base_log.0 as u32,
l_gadget.0 as u32,
num_samples,
ks_tmp_buffer,
uses_trivial_indices,
);
} else {
panic!("Unknown LWE GEMM KS dtype of size {}B", size_of::<KST>());
}
}
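// Design note on the dispatch above: the CUDA entry points are not generic, so
// the concrete `cuda_keyswitch_gemm_lwe_ciphertext_vector_64_*` symbol is
// selected at runtime from the monomorphized `KST` via `TypeId`; supporting a
// new KS torus width means adding both a binding and a branch here.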
/// Keyswitch on a vector of LWE ciphertexts. Better for small batches of LWEs
@@ -525,33 +546,50 @@ pub unsafe fn keyswitch_async_gemm<T: UnsignedInteger>(
/// [CudaStreams::synchronize] __must__ be called as soon as synchronization is
/// required
#[allow(clippy::too_many_arguments)]
pub unsafe fn keyswitch_async<T: UnsignedInteger>(
pub unsafe fn keyswitch_async<T: UnsignedInteger, KT: UnsignedInteger>(
streams: &CudaStreams,
lwe_array_out: &mut CudaVec<T>,
lwe_array_out: &mut CudaVec<KT>,
lwe_out_indexes: &CudaVec<T>,
lwe_array_in: &CudaVec<T>,
lwe_in_indexes: &CudaVec<T>,
input_lwe_dimension: LweDimension,
output_lwe_dimension: LweDimension,
keyswitch_key: &CudaVec<T>,
keyswitch_key: &CudaVec<KT>,
base_log: DecompositionBaseLog,
l_gadget: DecompositionLevelCount,
num_samples: u32,
) {
cuda_keyswitch_lwe_ciphertext_vector_64(
streams.ptr[0],
streams.gpu_indexes[0].get(),
lwe_array_out.as_mut_c_ptr(0),
lwe_out_indexes.as_c_ptr(0),
lwe_array_in.as_c_ptr(0),
lwe_in_indexes.as_c_ptr(0),
keyswitch_key.as_c_ptr(0),
input_lwe_dimension.0 as u32,
output_lwe_dimension.0 as u32,
base_log.0 as u32,
l_gadget.0 as u32,
num_samples,
);
if TypeId::of::<KT>() == TypeId::of::<u32>() {
cuda_keyswitch_lwe_ciphertext_vector_64_32(
streams.ptr[0],
streams.gpu_indexes[0].get(),
lwe_array_out.as_mut_c_ptr(0),
lwe_out_indexes.as_c_ptr(0),
lwe_array_in.as_c_ptr(0),
lwe_in_indexes.as_c_ptr(0),
keyswitch_key.as_c_ptr(0),
input_lwe_dimension.0 as u32,
output_lwe_dimension.0 as u32,
base_log.0 as u32,
l_gadget.0 as u32,
num_samples,
);
} else if TypeId::of::<KT>() == TypeId::of::<u64>() {
cuda_keyswitch_lwe_ciphertext_vector_64_64(
streams.ptr[0],
streams.gpu_indexes[0].get(),
lwe_array_out.as_mut_c_ptr(0),
lwe_out_indexes.as_c_ptr(0),
lwe_array_in.as_c_ptr(0),
lwe_in_indexes.as_c_ptr(0),
keyswitch_key.as_c_ptr(0),
input_lwe_dimension.0 as u32,
output_lwe_dimension.0 as u32,
base_log.0 as u32,
l_gadget.0 as u32,
num_samples,
);
} else {
panic!("Unknown LWE KS dtype of size {}B", size_of::<KT>());
}
}
/// Convert keyswitch key
///

View File

@@ -487,7 +487,10 @@ impl CudaServerKey {
}
pub fn gpu_indexes(&self) -> &[GpuIndex] {
&self.key.key.key_switching_key.d_vec.gpu_indexes
match &self.key.key.key_switching_key {
CudaDynamicKeyswitchingKey::KeySwitch32(ksk_32) => ksk_32.d_vec.gpu_indexes.as_slice(),
CudaDynamicKeyswitchingKey::Standard(std_key) => std_key.d_vec.gpu_indexes.as_slice(),
}
}
pub(in crate::high_level_api) fn re_randomization_cpk_casting_key(
&self,
@@ -611,6 +614,8 @@ use crate::high_level_api::keys::inner::IntegerServerKeyConformanceParams;
#[cfg(feature = "gpu")]
use crate::integer::gpu::key_switching_key::CudaKeySwitchingKeyMaterial;
#[cfg(feature = "gpu")]
use crate::integer::gpu::server_key::CudaDynamicKeyswitchingKey;
impl ParameterSetConformant for ServerKey {
type ParameterSet = IntegerServerKeyConformanceParams;

View File

@@ -10,7 +10,7 @@ use crate::integer::gpu::ciphertext::compressed_ciphertext_list::CudaExpandable;
use crate::integer::gpu::ciphertext::info::{CudaBlockInfo, CudaRadixCiphertextInfo};
use crate::integer::gpu::ciphertext::{CudaRadixCiphertext, CudaVec, KsType, LweDimension};
use crate::integer::gpu::key_switching_key::CudaKeySwitchingKey;
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
use crate::integer::gpu::{cuda_backend_expand, PBSType};
use crate::shortint::ciphertext::CompactCiphertextList;
use crate::shortint::parameters::{
@@ -404,7 +404,12 @@ impl CudaFlattenedVecCompactCiphertextList {
let d_input = &self.d_flattened_vec;
let casting_key = key.key_switching_key_material;
let sks = key.dest_server_key;
let computing_ks_key = &key.dest_server_key.key_switching_key;
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) =
&key.dest_server_key.key_switching_key
else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let casting_key_type: KsType = casting_key.destination_key.into();

View File

@@ -365,13 +365,17 @@ pub(crate) unsafe fn cuda_backend_scalar_addition_assign<T: UnsignedInteger>(
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
/// undefined behavior.
pub(crate) unsafe fn cuda_backend_unchecked_scalar_mul<T: UnsignedInteger, B: Numeric>(
pub(crate) unsafe fn cuda_backend_unchecked_scalar_mul<
T: UnsignedInteger,
KST: UnsignedInteger,
B: Numeric,
>(
streams: &CudaStreams,
lwe_array: &mut CudaRadixCiphertext,
decomposed_scalar: &[T],
has_at_least_one_set: &[T],
bootstrapping_key: &CudaVec<B>,
keyswitch_key: &CudaVec<u64>,
keyswitch_key: &CudaVec<KST>,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
glwe_dimension: GlweDimension,
@@ -1736,13 +1740,17 @@ pub(crate) fn cuda_backend_get_bitop_size_on_gpu(
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
/// undefined behavior.
pub(crate) unsafe fn cuda_backend_unchecked_scalar_bitop_assign<T: UnsignedInteger, B: Numeric>(
pub(crate) unsafe fn cuda_backend_unchecked_scalar_bitop_assign<
T: UnsignedInteger,
KST: UnsignedInteger,
B: Numeric,
>(
streams: &CudaStreams,
radix_lwe: &mut CudaRadixCiphertext,
clear_blocks: &CudaVec<T>,
h_clear_blocks: &[T],
bootstrapping_key: &CudaVec<B>,
keyswitch_key: &CudaVec<T>,
keyswitch_key: &CudaVec<KST>,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
glwe_dimension: GlweDimension,
@@ -2102,14 +2110,18 @@ pub(crate) fn cuda_backend_get_comparison_size_on_gpu(
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
/// undefined behavior.
pub(crate) unsafe fn cuda_backend_unchecked_scalar_comparison<T: UnsignedInteger, B: Numeric>(
pub(crate) unsafe fn cuda_backend_unchecked_scalar_comparison<
T: UnsignedInteger,
KST: UnsignedInteger,
B: Numeric,
>(
streams: &CudaStreams,
radix_lwe_out: &mut CudaRadixCiphertext,
radix_lwe_in: &CudaRadixCiphertext,
scalar_blocks: &CudaVec<T>,
h_scalar_blocks: &[T],
bootstrapping_key: &CudaVec<B>,
keyswitch_key: &CudaVec<T>,
keyswitch_key: &CudaVec<KST>,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
glwe_dimension: GlweDimension,
@@ -3032,7 +3044,11 @@ pub(crate) unsafe fn cuda_backend_grouped_oprf<B: Numeric>(
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
/// undefined behavior.
pub(crate) unsafe fn cuda_backend_grouped_oprf_custom_range<T: UnsignedInteger, B: Numeric>(
pub(crate) unsafe fn cuda_backend_grouped_oprf_custom_range<
T: UnsignedInteger,
B: Numeric,
KST: Numeric,
>(
streams: &CudaStreams,
radix_lwe_out: &mut CudaRadixCiphertext,
num_blocks_intermediate: u32,
@@ -3041,7 +3057,7 @@ pub(crate) unsafe fn cuda_backend_grouped_oprf_custom_range<T: UnsignedInteger,
has_at_least_one_set: &[T],
shift: u32,
bootstrapping_key: &CudaVec<B>,
key_switching_key: &CudaVec<u64>,
key_switching_key: &CudaVec<KST>,
lwe_dimension: LweDimension,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
@@ -5916,7 +5932,11 @@ pub(crate) unsafe fn cuda_backend_unchecked_partial_sum_ciphertexts_assign<
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
/// undefined behavior.
pub(crate) unsafe fn cuda_backend_apply_univariate_lut<T: UnsignedInteger, B: Numeric>(
pub(crate) unsafe fn cuda_backend_apply_univariate_lut<
T: UnsignedInteger,
KST: UnsignedInteger,
B: Numeric,
>(
streams: &CudaStreams,
output: &mut CudaSliceMut<T>,
output_degrees: &mut Vec<u64>,
@@ -5925,7 +5945,7 @@ pub(crate) unsafe fn cuda_backend_apply_univariate_lut<T: UnsignedInteger, B: Nu
input_lut: &[T],
lut_degree: u64,
bootstrapping_key: &CudaVec<B>,
keyswitch_key: &CudaVec<T>,
keyswitch_key: &CudaVec<KST>,
lwe_dimension: LweDimension,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
@@ -6023,7 +6043,11 @@ pub(crate) unsafe fn cuda_backend_apply_univariate_lut<T: UnsignedInteger, B: Nu
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
/// undefined behavior.
pub(crate) unsafe fn cuda_backend_apply_many_univariate_lut<T: UnsignedInteger, B: Numeric>(
pub(crate) unsafe fn cuda_backend_apply_many_univariate_lut<
T: UnsignedInteger,
KST: UnsignedInteger,
B: Numeric,
>(
streams: &CudaStreams,
output: &mut CudaSliceMut<T>,
output_degrees: &mut Vec<u64>,
@@ -6032,7 +6056,7 @@ pub(crate) unsafe fn cuda_backend_apply_many_univariate_lut<T: UnsignedInteger,
input_lut: &[T],
lut_degree: u64,
bootstrapping_key: &CudaVec<B>,
keyswitch_key: &CudaVec<T>,
keyswitch_key: &CudaVec<KST>,
lwe_dimension: LweDimension,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
@@ -7250,14 +7274,18 @@ pub(crate) unsafe fn cuda_backend_extend_radix_with_trivial_zero_blocks_msb(
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
/// undefined behavior.
pub(crate) unsafe fn cuda_backend_noise_squashing<T: UnsignedInteger, B: Numeric>(
pub(crate) unsafe fn cuda_backend_noise_squashing<
T: UnsignedInteger,
KST: UnsignedInteger,
B: Numeric,
>(
streams: &CudaStreams,
output: &mut CudaSliceMut<T>,
output_degrees: &mut Vec<u64>,
output_noise_levels: &mut Vec<u64>,
input: &CudaSlice<u64>,
bootstrapping_key: &CudaVec<B>,
keyswitch_key: &CudaVec<u64>,
keyswitch_key: &CudaVec<KST>,
lwe_dimension: LweDimension,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
@@ -7369,12 +7397,12 @@ pub(crate) unsafe fn cuda_backend_noise_squashing<T: UnsignedInteger, B: Numeric
/// that were inside that vector of compact lists. Handling the input this way removes the need
/// to process multiple compact lists separately, simplifying GPU-based operations. The variable
/// name `lwe_flattened_compact_array_in` makes this intent explicit.
pub(crate) unsafe fn cuda_backend_expand<T: UnsignedInteger, B: Numeric>(
pub(crate) unsafe fn cuda_backend_expand<T: UnsignedInteger, KST: UnsignedInteger, B: Numeric>(
streams: &CudaStreams,
lwe_array_out: &mut CudaLweCiphertextList<T>,
lwe_flattened_compact_array_in: &CudaVec<T>,
bootstrapping_key: &CudaVec<B>,
computing_ks_key: &CudaVec<T>,
computing_ks_key: &CudaVec<KST>,
casting_key: &CudaVec<T>,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,

View File

@@ -36,14 +36,18 @@ impl<Scalar: UnsignedInteger> CudaBootstrappingKey<Scalar> {
}
}
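/// Keyswitching key whose torus width is selected at runtime: `Standard` holds
/// a 64-bit KSK while `KeySwitch32` holds a 32-bit one, matching the KS32
/// atomic pattern.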
pub enum CudaDynamicKeyswitchingKey {
Standard(CudaLweKeyswitchKey<u64>),
KeySwitch32(CudaLweKeyswitchKey<u32>),
}
/// A structure containing the server public key.
///
/// The server key is generated by the client and is meant to be published: the client
/// sends it to the server so it can compute homomorphic circuits.
// #[derive(PartialEq, Serialize, Deserialize)]
pub struct CudaServerKey {
pub key_switching_key: CudaLweKeyswitchKey<u64>,
pub bootstrapping_key: CudaBootstrappingKey<u64>,
pub key_switching_key: CudaDynamicKeyswitchingKey,
pub bootstrapping_key: CudaBootstrappingKey<u64>, // the GGSW of the BSK
// Size of the message buffer
pub message_modulus: MessageModulus,
// Size of the carry buffer
@@ -180,7 +184,7 @@ impl CudaServerKey {
// Pack the keys in the server key set:
Self {
key_switching_key: d_key_switching_key,
key_switching_key: CudaDynamicKeyswitchingKey::Standard(d_key_switching_key),
bootstrapping_key: d_bootstrapping_key,
message_modulus: std_cks.parameters.message_modulus(),
carry_modulus: std_cks.parameters.carry_modulus(),
@@ -239,57 +243,108 @@ impl CudaServerKey {
max_noise_level,
} = cpu_key.key.clone();
// Generate a regular keyset and convert it to the GPU
let CompressedAtomicPatternServerKey::Standard(std_key) = compressed_ap_server_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let ciphertext_modulus = std_key.ciphertext_modulus();
let (key_switching_key, bootstrapping_key, pbs_order) = std_key.into_raw_parts();
let h_key_switching_key = key_switching_key.par_decompress_into_lwe_keyswitch_key();
let key_switching_key =
CudaLweKeyswitchKey::from_lwe_keyswitch_key(&h_key_switching_key, streams);
let bootstrapping_key = match bootstrapping_key {
crate::shortint::server_key::compressed::ShortintCompressedBootstrappingKey::Classic{ bsk: h_bootstrap_key, modulus_switch_noise_reduction_key, } => {
let modulus_switch_noise_reduction_configuration = match modulus_switch_noise_reduction_key {
CompressedModulusSwitchConfiguration::Standard => None,
CompressedModulusSwitchConfiguration::DriftTechniqueNoiseReduction(_modulus_switch_noise_reduction_key) => panic!("Drift noise reduction is not supported on GPU"),
CompressedModulusSwitchConfiguration::CenteredMeanNoiseReduction => Some(CudaModulusSwitchNoiseReductionConfiguration::Centered),
};
let standard_bootstrapping_key = h_bootstrap_key.par_decompress_into_lwe_bootstrap_key();
let d_bootstrap_key =
CudaLweBootstrapKey::from_lwe_bootstrap_key(&standard_bootstrapping_key, modulus_switch_noise_reduction_configuration, streams);
CudaBootstrappingKey::Classic(d_bootstrap_key)
}
crate::shortint::server_key::compressed::ShortintCompressedBootstrappingKey::MultiBit {
seeded_bsk: bootstrapping_key,
deterministic_execution: _,
} => {
let standard_bootstrapping_key =
bootstrapping_key.par_decompress_into_lwe_multi_bit_bootstrap_key();
let d_bootstrap_key =
CudaLweMultiBitBootstrapKey::from_lwe_multi_bit_bootstrap_key(
&standard_bootstrapping_key, streams);
CudaBootstrappingKey::MultiBit(d_bootstrap_key)
}
};
Self {
key_switching_key,
bootstrapping_key,
message_modulus,
carry_modulus,
max_degree,
max_noise_level,
ciphertext_modulus,
pbs_order,
}
let ciphertext_modulus = compressed_ap_server_key.ciphertext_modulus();
match compressed_ap_server_key {
CompressedAtomicPatternServerKey::Standard(std_key) => {
let (key_switching_key, bootstrapping_key, pbs_order) = std_key.into_raw_parts();
let h_key_switching_key = key_switching_key.par_decompress_into_lwe_keyswitch_key();
let key_switching_key =
CudaLweKeyswitchKey::from_lwe_keyswitch_key(&h_key_switching_key, streams);
let bootstrapping_key = match bootstrapping_key {
crate::shortint::server_key::compressed::ShortintCompressedBootstrappingKey::Classic{ bsk: h_bootstrap_key, modulus_switch_noise_reduction_key, } => {
let modulus_switch_noise_reduction_configuration = match modulus_switch_noise_reduction_key {
CompressedModulusSwitchConfiguration::Standard => None,
CompressedModulusSwitchConfiguration::DriftTechniqueNoiseReduction(_modulus_switch_noise_reduction_key) => panic!("Drift noise reduction is not supported on GPU"),
CompressedModulusSwitchConfiguration::CenteredMeanNoiseReduction => Some(CudaModulusSwitchNoiseReductionConfiguration::Centered),
};
let standard_bootstrapping_key = h_bootstrap_key.par_decompress_into_lwe_bootstrap_key();
let d_bootstrap_key =
CudaLweBootstrapKey::from_lwe_bootstrap_key(&standard_bootstrapping_key, modulus_switch_noise_reduction_configuration, streams);
CudaBootstrappingKey::Classic(d_bootstrap_key)
}
crate::shortint::server_key::compressed::ShortintCompressedBootstrappingKey::MultiBit {
seeded_bsk: bootstrapping_key,
deterministic_execution: _,
} => {
let standard_bootstrapping_key =
bootstrapping_key.par_decompress_into_lwe_multi_bit_bootstrap_key();
let d_bootstrap_key =
CudaLweMultiBitBootstrapKey::from_lwe_multi_bit_bootstrap_key(
&standard_bootstrapping_key, streams);
CudaBootstrappingKey::MultiBit(d_bootstrap_key)
}
};
Self {
key_switching_key: CudaDynamicKeyswitchingKey::Standard(key_switching_key),
bootstrapping_key,
message_modulus,
carry_modulus,
max_degree,
max_noise_level,
ciphertext_modulus,
pbs_order,
}
}
CompressedAtomicPatternServerKey::KeySwitch32(ks32_key) => {
let key_switching_key = ks32_key.key_switching_key();
let bootstrapping_key = ks32_key.bootstrapping_key();
let h_key_switching_key = key_switching_key
.as_view()
.par_decompress_into_lwe_keyswitch_key();
let key_switching_key =
CudaLweKeyswitchKey::from_lwe_keyswitch_key(&h_key_switching_key, streams);
let bootstrapping_key = match bootstrapping_key {
crate::shortint::server_key::compressed::ShortintCompressedBootstrappingKey::Classic{ bsk: h_bootstrap_key, modulus_switch_noise_reduction_key, } => {
let modulus_switch_noise_reduction_configuration = match modulus_switch_noise_reduction_key {
CompressedModulusSwitchConfiguration::Standard => None,
CompressedModulusSwitchConfiguration::DriftTechniqueNoiseReduction(_modulus_switch_noise_reduction_key) => panic!("Drift noise reduction is not supported on GPU"),
CompressedModulusSwitchConfiguration::CenteredMeanNoiseReduction => Some(CudaModulusSwitchNoiseReductionConfiguration::Centered),
};
let standard_bootstrapping_key = h_bootstrap_key.as_view().par_decompress_into_lwe_bootstrap_key();
let d_bootstrap_key =
CudaLweBootstrapKey::from_lwe_bootstrap_key(&standard_bootstrapping_key, modulus_switch_noise_reduction_configuration, streams);
CudaBootstrappingKey::Classic(d_bootstrap_key)
}
crate::shortint::server_key::compressed::ShortintCompressedBootstrappingKey::MultiBit {
seeded_bsk: bootstrapping_key,
deterministic_execution: _,
} => {
let standard_bootstrapping_key =
bootstrapping_key.as_view().par_decompress_into_lwe_multi_bit_bootstrap_key();
let d_bootstrap_key =
CudaLweMultiBitBootstrapKey::from_lwe_multi_bit_bootstrap_key(
&standard_bootstrapping_key, streams);
CudaBootstrappingKey::MultiBit(d_bootstrap_key)
}
};
Self {
key_switching_key: CudaDynamicKeyswitchingKey::KeySwitch32(key_switching_key),
bootstrapping_key,
message_modulus,
carry_modulus,
max_degree,
max_noise_level,
ciphertext_modulus,
pbs_order: PBSOrder::KeyswitchBootstrap,
}
}
}
}
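The dispatch above implies a two-variant wrapper around the GPU keyswitch key. A minimal sketch of its likely shape, assuming the KS32 variant stores a 32-bit key per the commit's 64/32 naming (the exact definition is not visible in this hunk):

// Hypothetical sketch; the real definition lives in the server_key module
// and may differ.
pub enum CudaDynamicKeyswitchingKey {
    /// Standard atomic pattern: 64-bit computing keyswitch key.
    Standard(CudaLweKeyswitchKey<u64>),
    /// KS32 atomic pattern: keyswitch from 64-bit down to 32-bit ciphertexts.
    KeySwitch32(CudaLweKeyswitchKey<u32>),
}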

View File

@@ -1,7 +1,9 @@
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::server_key::{
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
};
use crate::integer::gpu::{cuda_backend_unchecked_signed_abs_assign, PBSType};
impl CudaServerKey {
@@ -11,6 +13,10 @@ impl CudaServerKey {
{
let num_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0 as u32;
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -18,19 +24,15 @@ impl CudaServerKey {
streams,
ct.as_mut(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
num_blocks,
@@ -44,19 +46,15 @@ impl CudaServerKey {
streams,
ct.as_mut(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
num_blocks,
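Every integer operation touched by this commit applies the same guard as above: destructure the Standard variant or panic. A self-contained toy of that control flow (the payload types are placeholders, not the tfhe-rs ones):

// Toy reproduction of the guard pattern used at each call site.
enum DynamicKsKey {
    Standard(Vec<u64>),    // stands in for CudaLweKeyswitchKey<u64>
    KeySwitch32(Vec<u32>), // stands in for the KS32 key
}

fn block_count(key: &DynamicKsKey) -> usize {
    // Rust 1.65+ let-else: bind the Standard payload or diverge.
    let DynamicKsKey::Standard(computing_ks_key) = key else {
        panic!("Only the standard atomic pattern is supported on GPU")
    };
    computing_ks_key.len()
}

fn main() {
    assert_eq!(block_count(&DynamicKsKey::Standard(vec![0; 8])), 8);
}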

View File

@@ -5,7 +5,9 @@ use crate::integer::gpu::ciphertext::{
CudaIntegerRadixCiphertext, CudaRadixCiphertext, CudaSignedRadixCiphertext,
CudaUnsignedRadixCiphertext,
};
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::server_key::{
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
};
use crate::integer::gpu::{
cuda_backend_add_and_propagate_single_carry_assign,
cuda_backend_get_add_and_propagate_single_carry_assign_size_on_gpu,
@@ -127,6 +129,10 @@ impl CudaServerKey {
ct_left.as_ref().d_blocks.lwe_ciphertext_count().0,
ct_right.as_ref().d_blocks.lwe_ciphertext_count().0
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let full_prop_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_get_full_propagate_assign_size_on_gpu(
@@ -134,8 +140,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -151,8 +157,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
@@ -181,8 +187,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
num_blocks,
@@ -200,8 +206,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
num_blocks,
@@ -331,6 +337,9 @@ impl CudaServerKey {
let radix_count_in_vec = ciphertexts.len();
let mut terms = CudaRadixCiphertext::from_radix_ciphertext_vec(ciphertexts, streams);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -341,16 +350,14 @@ impl CudaServerKey {
&mut terms,
reduce_degrees_for_single_carry_propagation,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
num_blocks.0 as u32,
@@ -367,16 +374,14 @@ impl CudaServerKey {
&mut terms,
reduce_degrees_for_single_carry_propagation,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
num_blocks.0 as u32,
@@ -694,6 +699,9 @@ impl CudaServerKey {
let aux_block: T = self.create_trivial_zero_radix(1, streams);
let in_carry: &CudaRadixCiphertext =
input_carry.map_or_else(|| aux_block.as_ref(), |block| block.0.as_ref());
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -705,12 +713,12 @@ impl CudaServerKey {
carry_out.as_mut(),
in_carry,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
num_blocks,
@@ -731,12 +739,12 @@ impl CudaServerKey {
carry_out.as_mut(),
in_carry,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
num_blocks,

View File

@@ -2,7 +2,9 @@ use crate::core_crypto::gpu::{
check_valid_cuda_malloc, check_valid_cuda_malloc_assert_oom, CudaStreams,
};
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::server_key::{
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
};
use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::gpu::{
@@ -271,6 +273,10 @@ impl CudaServerKey {
result.as_ref().d_blocks.lwe_ciphertext_count().0
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -283,14 +289,14 @@ impl CudaServerKey {
num_aes_inputs as u32,
sbox_parallelism as u32,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
@@ -308,14 +314,14 @@ impl CudaServerKey {
num_aes_inputs as u32,
sbox_parallelism as u32,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
@@ -334,6 +340,10 @@ impl CudaServerKey {
sbox_parallelism: usize,
streams: &CudaStreams,
) -> u64 {
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_aes_ctr_encrypt_size_on_gpu(
streams,
@@ -344,8 +354,8 @@ impl CudaServerKey {
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
@@ -362,8 +372,8 @@ impl CudaServerKey {
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
@@ -392,6 +402,10 @@ impl CudaServerKey {
key.as_ref().d_blocks.lwe_ciphertext_count().0
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -400,14 +414,14 @@ impl CudaServerKey {
expanded_keys.as_mut(),
key.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
@@ -421,14 +435,14 @@ impl CudaServerKey {
expanded_keys.as_mut(),
key.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
@@ -442,6 +456,10 @@ impl CudaServerKey {
}
pub fn get_key_expansion_size_on_gpu(&self, streams: &CudaStreams) -> u64 {
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_aes_key_expansion_size_on_gpu(
streams,
@@ -450,8 +468,8 @@ impl CudaServerKey {
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
@@ -466,8 +484,8 @@ impl CudaServerKey {
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,

View File

@@ -2,7 +2,9 @@ use crate::core_crypto::gpu::{
check_valid_cuda_malloc, check_valid_cuda_malloc_assert_oom, CudaStreams,
};
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::server_key::{
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
};
use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::gpu::{
@@ -197,6 +199,10 @@ impl CudaServerKey {
result.as_ref().d_blocks.lwe_ciphertext_count().0
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -209,14 +215,14 @@ impl CudaServerKey {
num_aes_inputs as u32,
sbox_parallelism as u32,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
@@ -234,14 +240,14 @@ impl CudaServerKey {
num_aes_inputs as u32,
sbox_parallelism as u32,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
@@ -274,6 +280,10 @@ impl CudaServerKey {
key.as_ref().d_blocks.lwe_ciphertext_count().0
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -282,14 +292,14 @@ impl CudaServerKey {
expanded_keys.as_mut(),
key.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
@@ -303,14 +313,14 @@ impl CudaServerKey {
expanded_keys.as_mut(),
key.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
@@ -324,6 +334,10 @@ impl CudaServerKey {
}
pub fn get_key_expansion_256_size_on_gpu(&self, streams: &CudaStreams) -> u64 {
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_get_aes_key_expansion_256_size_on_gpu(
@@ -333,8 +347,8 @@ impl CudaServerKey {
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
@@ -350,8 +364,8 @@ impl CudaServerKey {
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,

View File

@@ -2,7 +2,7 @@ use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaRadixCiphertext};
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
use crate::integer::gpu::{
cuda_backend_boolean_bitnot_assign, cuda_backend_boolean_bitop_assign,
cuda_backend_get_bitop_size_on_gpu, cuda_backend_get_boolean_bitnot_size_on_gpu,
@@ -324,6 +324,10 @@ impl CudaServerKey {
ct_right.0.as_ref().d_blocks.lwe_ciphertext_count()
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -332,19 +336,15 @@ impl CudaServerKey {
ct_left.0.as_mut(),
ct_right.0.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
op,
@@ -361,19 +361,15 @@ impl CudaServerKey {
ct_left.0.as_mut(),
ct_right.0.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
op,
@@ -412,6 +408,10 @@ impl CudaServerKey {
is_unchecked: bool,
streams: &CudaStreams,
) {
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -420,15 +420,15 @@ impl CudaServerKey {
&mut ct.0.ciphertext as &mut CudaRadixCiphertext,
is_unchecked,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
d_bsk.output_lwe_dimension(),
d_bsk.input_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
PBSType::Classical,
@@ -442,15 +442,15 @@ impl CudaServerKey {
&mut ct.0.ciphertext as &mut CudaRadixCiphertext,
is_unchecked,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
d_multibit_bsk.output_lwe_dimension(),
d_multibit_bsk.input_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
PBSType::MultiBit,
@@ -534,6 +534,10 @@ impl CudaServerKey {
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -542,19 +546,15 @@ impl CudaServerKey {
ct_left.as_mut(),
ct_right.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
op,
@@ -570,19 +570,15 @@ impl CudaServerKey {
ct_left.as_mut(),
ct_right.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
op,
@@ -611,6 +607,10 @@ impl CudaServerKey {
ct_left.0.as_ref().d_blocks.lwe_ciphertext_count(),
ct_right.0.as_ref().d_blocks.lwe_ciphertext_count()
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let boolean_bitop_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_boolean_bitop_size_on_gpu(
streams,
@@ -618,14 +618,10 @@ impl CudaServerKey {
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
op,
@@ -642,14 +638,10 @@ impl CudaServerKey {
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
op,
@@ -679,6 +671,10 @@ impl CudaServerKey {
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let full_prop_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_get_full_propagate_assign_size_on_gpu(
@@ -686,8 +682,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -703,8 +699,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
@@ -734,14 +730,10 @@ impl CudaServerKey {
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
op,
@@ -756,14 +748,10 @@ impl CudaServerKey {
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
op,
@@ -1287,6 +1275,10 @@ impl CudaServerKey {
_ct: &CudaBooleanBlock,
streams: &CudaStreams,
) -> u64 {
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let boolean_bitnot_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_boolean_bitnot_size_on_gpu(
streams,
@@ -1294,14 +1286,10 @@ impl CudaServerKey {
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
false,
@@ -1317,14 +1305,10 @@ impl CudaServerKey {
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
false,
@@ -1344,6 +1328,10 @@ impl CudaServerKey {
ct: &T,
streams: &CudaStreams,
) -> u64 {
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let full_prop_mem = if ct.block_carries_are_empty() {
0
} else {
@@ -1354,8 +1342,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -1371,8 +1359,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
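The `input_key_lwe_size().to_lwe_dimension()` calls in these hunks rely on the standard LWE convention: a ciphertext of dimension n stores n mask coefficients plus one body, so size = n + 1. A toy illustration with simplified stand-ins for the core_crypto newtypes:

// Simplified stand-ins for the core_crypto wrapper types.
#[derive(Debug, PartialEq)]
struct LweDimension(usize);

struct LweSize(usize);

impl LweSize {
    // n mask elements plus one body: dimension = size - 1.
    fn to_lwe_dimension(&self) -> LweDimension {
        LweDimension(self.0 - 1)
    }
}

fn main() {
    assert_eq!(LweSize(841).to_lwe_dimension(), LweDimension(840));
}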

View File

@@ -2,7 +2,7 @@ use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
use crate::integer::gpu::{
cuda_backend_get_cmux_size_on_gpu, cuda_backend_get_full_propagate_assign_size_on_gpu,
cuda_backend_unchecked_cmux, CudaServerKey, PBSType,
@@ -20,6 +20,10 @@ impl CudaServerKey {
let mut result: T = self
.create_trivial_zero_radix(true_ct.as_ref().d_blocks.lwe_ciphertext_count().0, stream);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -30,19 +34,15 @@ impl CudaServerKey {
true_ct.as_ref(),
false_ct.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -59,19 +59,15 @@ impl CudaServerKey {
true_ct.as_ref(),
false_ct.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -129,6 +125,10 @@ impl CudaServerKey {
true_ct.as_ref().d_blocks.lwe_ciphertext_count(),
false_ct.as_ref().d_blocks.lwe_ciphertext_count()
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let full_prop_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_get_full_propagate_assign_size_on_gpu(
@@ -136,8 +136,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -153,8 +153,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
@@ -184,14 +184,10 @@ impl CudaServerKey {
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -205,14 +201,10 @@ impl CudaServerKey {
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,

View File

@@ -4,7 +4,7 @@ use crate::core_crypto::prelude::{LweBskGroupingFactor, LweCiphertextCount};
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::info::CudaRadixCiphertextInfo;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaRadixCiphertext};
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
use crate::integer::gpu::{
cuda_backend_get_comparison_size_on_gpu, cuda_backend_get_full_propagate_assign_size_on_gpu,
cuda_backend_unchecked_comparison, ComparisonType, CudaServerKey, PBSType,
@@ -45,6 +45,10 @@ impl CudaServerKey {
let mut result =
CudaBooleanBlock::from_cuda_radix_ciphertext(CudaRadixCiphertext::new(block, ct_info));
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -54,19 +58,15 @@ impl CudaServerKey {
ct_left.as_ref(),
ct_right.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
op,
@@ -83,19 +83,15 @@ impl CudaServerKey {
ct_left.as_ref(),
ct_right.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
op,
@@ -226,6 +222,10 @@ impl CudaServerKey {
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let full_prop_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_get_full_propagate_assign_size_on_gpu(
@@ -233,8 +233,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -250,8 +250,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
@@ -273,6 +273,9 @@ impl CudaServerKey {
};
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let comparison_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_comparison_size_on_gpu(
@@ -281,14 +284,10 @@ impl CudaServerKey {
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -305,14 +304,10 @@ impl CudaServerKey {
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -822,6 +817,10 @@ impl CudaServerKey {
let mut result = ct_left.duplicate(streams);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -831,19 +830,15 @@ impl CudaServerKey {
ct_left.as_ref(),
ct_right.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
ComparisonType::MAX,
@@ -860,19 +855,15 @@ impl CudaServerKey {
ct_left.as_ref(),
ct_right.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
ComparisonType::MAX,
@@ -902,6 +893,10 @@ impl CudaServerKey {
let mut result = ct_left.duplicate(streams);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -911,19 +906,15 @@ impl CudaServerKey {
ct_left.as_ref(),
ct_right.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
ComparisonType::MIN,
@@ -940,19 +931,15 @@ impl CudaServerKey {
ct_left.as_ref(),
ct_right.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
ComparisonType::MIN,

View File

@@ -1,7 +1,9 @@
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::server_key::{
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
};
use crate::integer::gpu::{
cuda_backend_get_div_rem_size_on_gpu, cuda_backend_get_full_propagate_assign_size_on_gpu,
cuda_backend_unchecked_div_rem_assign, PBSType,
@@ -18,6 +20,10 @@ impl CudaServerKey {
) where
T: CudaIntegerRadixCiphertext,
{
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
// TODO add asserts from `unchecked_div_rem_parallelized`
let num_blocks = divisor.as_ref().d_blocks.lwe_ciphertext_count().0 as u32;
unsafe {
@@ -31,19 +37,15 @@ impl CudaServerKey {
divisor.as_ref(),
T::IS_SIGNED,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
num_blocks,
@@ -61,19 +63,15 @@ impl CudaServerKey {
divisor.as_ref(),
T::IS_SIGNED,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
num_blocks,
@@ -227,6 +225,10 @@ impl CudaServerKey {
numerator.as_ref().d_blocks.lwe_ciphertext_count(),
divisor.as_ref().d_blocks.lwe_ciphertext_count()
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let full_prop_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_get_full_propagate_assign_size_on_gpu(
@@ -234,8 +236,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -251,8 +253,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
@@ -283,14 +285,10 @@ impl CudaServerKey {
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -305,14 +303,10 @@ impl CudaServerKey {
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,

View File

@@ -4,7 +4,9 @@ use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::{
CudaIntegerRadixCiphertext, CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext,
};
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::server_key::{
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
};
use crate::integer::gpu::{cuda_backend_count_of_consecutive_bits, cuda_backend_ilog2, PBSType};
use crate::integer::server_key::radix_parallel::ilog2::{BitValue, Direction};
@@ -34,6 +36,10 @@ impl CudaServerKey {
let mut result: CudaUnsignedRadixCiphertext =
self.create_trivial_zero_radix(counter_num_blocks, streams);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -42,12 +48,12 @@ impl CudaServerKey {
result.as_mut(),
ct.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -65,12 +71,12 @@ impl CudaServerKey {
result.as_mut(),
ct.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
@@ -190,6 +196,10 @@ impl CudaServerKey {
let mut result: CudaUnsignedRadixCiphertext =
self.create_trivial_zero_radix(counter_num_blocks, streams);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -201,12 +211,12 @@ impl CudaServerKey {
trivial_ct_2.as_ref(),
trivial_ct_m_minus_1_block.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
LweBskGroupingFactor(0),
@@ -228,12 +238,12 @@ impl CudaServerKey {
trivial_ct_2.as_ref(),
trivial_ct_m_minus_1_block.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
d_multibit_bsk.grouping_factor,
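A note on the recurring pattern in these hunks: every call site now destructures the dynamic keyswitching key with `let ... else` before reaching the backend, so the decomposition parameters are read from the unwrapped `computing_ks_key` rather than from `self.key_switching_key` directly. A minimal, self-contained sketch of the idea follows — the type shapes are illustrative stand-ins, not the crate's actual definitions:

// Sketch of the dispatch pattern used above; names and fields are
// hypothetical stand-ins for the real TFHE-rs types.

/// Hypothetical device-side keyswitching key carrying the decomposition
/// parameters that the call sites above read.
struct StandardKsk {
    decomposition_level_count: u32,
    decomposition_base_log: u32,
}

/// The keyswitching key is now dynamic: only the `Standard` variant is
/// usable on GPU, so every call site unwraps it (or panics) up front.
#[allow(dead_code)]
enum CudaDynamicKeyswitchingKey {
    Standard(StandardKsk),
    /// Stand-in for any non-standard atomic pattern, rejected on GPU.
    Other,
}

struct ServerKey {
    key_switching_key: CudaDynamicKeyswitchingKey,
}

fn use_ksk(key: &ServerKey) {
    // `let ... else` keeps the happy path flat: bind the inner key or bail.
    let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &key.key_switching_key else {
        panic!("Only the standard atomic pattern is supported on GPU")
    };
    // From here on, parameters come from the unwrapped key, not `self`.
    let _ = (
        computing_ks_key.decomposition_level_count,
        computing_ks_key.decomposition_base_log,
    );
}

fn main() {
    let key = ServerKey {
        key_switching_key: CudaDynamicKeyswitchingKey::Standard(StandardKsk {
            decomposition_level_count: 5,
            decomposition_base_log: 3,
        }),
    };
    use_ksk(&key);
}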

View File

@@ -14,7 +14,7 @@ use crate::integer::gpu::ciphertext::{
CudaUnsignedRadixCiphertext,
};
use crate::integer::gpu::noise_squashing::keys::CudaNoiseSquashingKey;
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
use crate::integer::gpu::{
cuda_backend_apply_many_univariate_lut, cuda_backend_apply_univariate_lut,
cuda_backend_cast_to_signed, cuda_backend_cast_to_unsigned,
@@ -182,9 +182,13 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
{
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let lwe_size = match self.pbs_order {
PBSOrder::KeyswitchBootstrap => self.key_switching_key.input_key_lwe_size(),
PBSOrder::BootstrapKeyswitch => self.key_switching_key.output_key_lwe_size(),
PBSOrder::KeyswitchBootstrap => computing_ks_key.input_key_lwe_size(),
PBSOrder::BootstrapKeyswitch => computing_ks_key.output_key_lwe_size(),
};
let decomposer =
@@ -235,6 +239,10 @@ impl CudaServerKey {
let in_carry: &CudaRadixCiphertext =
input_carry.map_or_else(|| aux_block.as_ref(), |block| block.0.as_ref());
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -244,12 +252,12 @@ impl CudaServerKey {
carry_out.as_mut(),
in_carry,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
num_blocks,
@@ -269,12 +277,12 @@ impl CudaServerKey {
carry_out.as_mut(),
in_carry,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
num_blocks,
@@ -299,6 +307,10 @@ impl CudaServerKey {
) {
let ciphertext = ct.as_mut();
let num_blocks = ciphertext.d_blocks.lwe_ciphertext_count().0 as u32;
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -306,12 +318,12 @@ impl CudaServerKey {
streams,
ciphertext,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
num_blocks,
@@ -327,12 +339,12 @@ impl CudaServerKey {
streams,
ciphertext,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
num_blocks,
@@ -689,6 +701,10 @@ impl CudaServerKey {
let mut output_degrees = vec![0_u64; num_output_blocks];
let mut output_noise_levels = vec![0_u64; num_output_blocks];
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let num_ct_blocks = block_range.len() as u32;
unsafe {
match &self.bootstrapping_key {
@@ -702,14 +718,12 @@ impl CudaServerKey {
lut.acc.as_ref(),
lut.degree.0,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
&computing_ks_key.d_vec,
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
num_ct_blocks,
@@ -730,14 +744,12 @@ impl CudaServerKey {
lut.acc.as_ref(),
lut.degree.0,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
&computing_ks_key.d_vec,
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
num_ct_blocks,
@@ -860,6 +872,9 @@ impl CudaServerKey {
.unwrap();
let mut output_degrees = vec![0_u64; num_ct_blocks * function_count];
let mut output_noise_levels = vec![0_u64; num_ct_blocks * function_count];
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -873,14 +888,12 @@ impl CudaServerKey {
lut.acc.as_ref(),
lut.input_max_degree.0,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
&computing_ks_key.d_vec,
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
num_ct_blocks as u32,
@@ -903,14 +916,12 @@ impl CudaServerKey {
lut.acc.as_ref(),
lut.input_max_degree.0,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
&computing_ks_key.d_vec,
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
num_ct_blocks as u32,
@@ -1001,6 +1012,9 @@ impl CudaServerKey {
self.create_trivial_zero_radix(target_num_blocks, streams);
let requires_full_propagate = !source.block_carries_are_empty();
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -1013,17 +1027,13 @@ impl CudaServerKey {
requires_full_propagate,
target_num_blocks as u32,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
@@ -1040,17 +1050,13 @@ impl CudaServerKey {
requires_full_propagate,
target_num_blocks as u32,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
@@ -1114,6 +1120,10 @@ impl CudaServerKey {
let mut output_ct: CudaSignedRadixCiphertext =
self.create_trivial_zero_radix(target_num_blocks, streams);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -1123,16 +1133,14 @@ impl CudaServerKey {
source.as_ref(),
T::IS_SIGNED,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
@@ -1147,16 +1155,14 @@ impl CudaServerKey {
source.as_ref(),
T::IS_SIGNED,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
@@ -1202,6 +1208,10 @@ impl CudaServerKey {
d_multibit_bsk.polynomial_size(),
),
};
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &d_bootstrapping_key {
CudaBootstrappingKey::Classic(bsk) => {
@@ -1212,16 +1222,14 @@ impl CudaServerKey {
&mut output_noise_levels,
&input_slice,
&bsk.d_vec,
&self.key_switching_key.d_vec,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
&computing_ks_key.d_vec,
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
bsk.glwe_dimension,
bsk.polynomial_size,
input_glwe_dimension,
input_polynomial_size,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
bsk.decomp_level_count,
bsk.decomp_base_log,
num_output_blocks as u32,
@@ -1241,16 +1249,14 @@ impl CudaServerKey {
&mut output_noise_levels,
&input_slice,
&mb_bsk.d_vec,
&self.key_switching_key.d_vec,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
&computing_ks_key.d_vec,
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
mb_bsk.glwe_dimension,
mb_bsk.polynomial_size,
input_glwe_dimension,
input_polynomial_size,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
mb_bsk.decomp_level_count,
mb_bsk.decomp_base_log,
num_output_blocks as u32,
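Many of the sites above end a chain with `.output_key_lwe_size().to_lwe_dimension()`. Under the usual LWE convention, an encryption of dimension n stores n mask elements plus one body, so size = n + 1; the `PBSOrder` match earlier in this file picks the input or output key size depending on whether the keyswitch runs before or after the bootstrap. A small sketch of the conversion under that assumption, with stand-in types rather than the crate's:

// Sketch of the size/dimension conversion used at the call sites above.
// Assumes the usual LWE convention: size = dimension + 1 (mask + body).

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct LweSize(usize);

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct LweDimension(usize);

impl LweSize {
    fn to_lwe_dimension(self) -> LweDimension {
        // The body element does not count toward the dimension.
        LweDimension(self.0 - 1)
    }
}

fn main() {
    // E.g. a keyswitching key whose output key has 742 mask elements.
    let output_key_lwe_size = LweSize(743);
    assert_eq!(output_key_lwe_size.to_lwe_dimension(), LweDimension(742));
}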

View File

@@ -1,7 +1,9 @@
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::server_key::{
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
};
use crate::integer::gpu::{
cuda_backend_get_full_propagate_assign_size_on_gpu, cuda_backend_get_mul_size_on_gpu,
cuda_backend_unchecked_mul_assign, PBSType,
@@ -74,6 +76,10 @@ impl CudaServerKey {
let is_boolean_left = ct_left.holds_boolean_value();
let is_boolean_right = ct_right.holds_boolean_value();
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -84,7 +90,7 @@ impl CudaServerKey {
ct_right.as_ref(),
is_boolean_right,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension(),
@@ -92,8 +98,8 @@ impl CudaServerKey {
d_bsk.polynomial_size(),
d_bsk.decomp_base_log(),
d_bsk.decomp_level_count(),
self.key_switching_key.decomposition_base_log(),
self.key_switching_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
num_blocks,
PBSType::Classical,
LweBskGroupingFactor(0),
@@ -108,7 +114,7 @@ impl CudaServerKey {
ct_right.as_ref(),
is_boolean_right,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension(),
@@ -116,8 +122,8 @@ impl CudaServerKey {
d_multibit_bsk.polynomial_size(),
d_multibit_bsk.decomp_base_log(),
d_multibit_bsk.decomp_level_count(),
self.key_switching_key.decomposition_base_log(),
self.key_switching_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
num_blocks,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
@@ -233,6 +239,10 @@ impl CudaServerKey {
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let full_prop_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_get_full_propagate_assign_size_on_gpu(
@@ -240,8 +250,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -257,8 +267,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
@@ -291,14 +301,12 @@ impl CudaServerKey {
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
d_bsk.polynomial_size,
d_bsk.decomp_base_log,
d_bsk.decomp_level_count,
self.key_switching_key.decomposition_base_log(),
self.key_switching_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
@@ -311,14 +319,12 @@ impl CudaServerKey {
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
d_multibit_bsk.polynomial_size,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.decomp_level_count,
self.key_switching_key.decomposition_base_log(),
self.key_switching_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
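The `decomposition_base_log` / `decomposition_level_count` pair threaded through every call above parameterizes the gadget decomposition used by the keyswitch: each of the `level_count` levels consumes `base_log` bits from the most significant end of a sample, so only the top `base_log * level_count` bits take part. A textbook sketch of rounding a 64-bit value to its closest representable under those parameters — this assumes `base_log * level_count < 64` and is not the backend's actual kernel:

// Sketch: round a 64-bit torus value to the closest value representable
// by a gadget decomposition with the given base_log / level_count.
fn closest_representable(input: u64, base_log: u32, level_count: u32) -> u64 {
    // Bits below the decomposition's reach.
    let non_rep_bits = 64 - base_log * level_count;
    // Round to nearest on the representable bits.
    let rounding = 1u64 << (non_rep_bits - 1);
    input.wrapping_add(rounding) >> non_rep_bits << non_rep_bits
}

fn main() {
    // With base_log = 3 and level_count = 5, only the top 15 bits survive.
    let x = 0x1234_5678_9ABC_DEF0u64;
    let r = closest_representable(x, 3, 5);
    assert_eq!(r & ((1u64 << 49) - 1), 0); // low 49 bits cleared
}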

View File

@@ -3,7 +3,9 @@ use crate::integer::gpu::ciphertext::{
CudaIntegerRadixCiphertext, CudaRadixCiphertext, CudaSignedRadixCiphertext,
CudaUnsignedRadixCiphertext,
};
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::server_key::{
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
};
use itertools::Itertools;
use crate::core_crypto::commons::generators::DeterministicSeeder;
@@ -351,6 +353,9 @@ impl CudaServerKey {
}
let message_bits_count = self.message_modulus.0.ilog2();
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -364,8 +369,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
@@ -387,8 +392,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
d_bsk.grouping_factor,
@@ -484,6 +489,10 @@ impl CudaServerKey {
let mut result: CudaUnsignedRadixCiphertext =
self.create_trivial_zero_radix(num_blocks_output as usize, streams);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -496,12 +505,12 @@ impl CudaServerKey {
has_at_least_one_set.as_slice(),
num_input_random_bits as u32,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_bsk.input_lwe_dimension,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
@@ -523,12 +532,12 @@ impl CudaServerKey {
has_at_least_one_set.as_slice(),
num_input_random_bits as u32,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_bsk.input_lwe_dimension,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
d_bsk.grouping_factor,
@@ -553,6 +562,9 @@ impl CudaServerKey {
streams: &CudaStreams,
) -> u64 {
let message_bits = self.message_modulus.0.ilog2();
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_grouped_oprf_size_on_gpu(
@@ -561,8 +573,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
@@ -579,8 +591,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
d_bsk.grouping_factor,
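Every entry point dispatches on the bootstrapping-key flavor, and the two arms differ only in the grouping factor: the classical arm passes `LweBskGroupingFactor(0)` while the multi-bit arm forwards the factor stored in the key itself. A minimal sketch of that dispatch, again with stand-in types:

// Sketch of the Classic/MultiBit dispatch repeated at every call site.
// All names are illustrative stand-ins for the crate's real types.

#[derive(Clone, Copy)]
struct LweBskGroupingFactor(u32);

struct ClassicBsk;
struct MultiBitBsk {
    grouping_factor: LweBskGroupingFactor,
}

enum CudaBootstrappingKey {
    Classic(ClassicBsk),
    MultiBit(MultiBitBsk),
}

fn launch(_grouping_factor: LweBskGroupingFactor) {
    // Stand-in for the backend call that receives the grouping factor.
}

fn dispatch(bsk: &CudaBootstrappingKey) {
    match bsk {
        // Classical PBS has no grouping; the convention is to pass 0.
        CudaBootstrappingKey::Classic(_) => launch(LweBskGroupingFactor(0)),
        // Multi-bit PBS forwards the factor stored in the key itself.
        CudaBootstrappingKey::MultiBit(bsk) => launch(bsk.grouping_factor),
    }
}

fn main() {
    dispatch(&CudaBootstrappingKey::Classic(ClassicBsk));
    dispatch(&CudaBootstrappingKey::MultiBit(MultiBitBsk {
        grouping_factor: LweBskGroupingFactor(2),
    }));
}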

View File

@@ -1,7 +1,7 @@
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
use crate::integer::gpu::{
cuda_backend_get_full_propagate_assign_size_on_gpu, cuda_backend_get_rotate_left_size_on_gpu,
cuda_backend_get_rotate_right_size_on_gpu, cuda_backend_unchecked_rotate_left_assign,
@@ -19,6 +19,9 @@ impl CudaServerKey {
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let is_signed = T::IS_SIGNED;
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -28,19 +31,15 @@ impl CudaServerKey {
ct.as_mut(),
rotate.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -56,19 +55,15 @@ impl CudaServerKey {
ct.as_mut(),
rotate.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -106,6 +101,9 @@ impl CudaServerKey {
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let is_signed = T::IS_SIGNED;
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -115,19 +113,15 @@ impl CudaServerKey {
ct.as_mut(),
rotate.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -143,19 +137,15 @@ impl CudaServerKey {
ct.as_mut(),
rotate.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -437,6 +427,10 @@ impl CudaServerKey {
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let full_prop_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_get_full_propagate_assign_size_on_gpu(
@@ -444,8 +438,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -461,8 +455,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
@@ -492,14 +486,10 @@ impl CudaServerKey {
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -515,14 +505,10 @@ impl CudaServerKey {
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -550,6 +536,10 @@ impl CudaServerKey {
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let full_prop_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_get_full_propagate_assign_size_on_gpu(
@@ -557,8 +547,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -574,8 +564,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
@@ -605,14 +595,10 @@ impl CudaServerKey {
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -628,14 +614,10 @@ impl CudaServerKey {
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
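The `*_size_on_gpu` helpers take the same parameter set as the launch functions but only report scratch requirements, and the callers above budget a full carry propagation only when the input blocks still carry. The hunks do not show how the two figures are combined; the sketch below simply adds them for illustration, with placeholder sizes:

// Sketch of the memory-accounting pattern in the *_size_on_gpu helpers:
// a full carry propagation is only budgeted when carries are present.
// Sizes here are illustrative placeholders, not the backend's values.

fn full_propagate_size_on_gpu() -> u64 {
    1 << 20 // stand-in for the full-propagate scratch-size query
}

fn op_size_on_gpu() -> u64 {
    1 << 22 // stand-in for the operation's own scratch size
}

fn total_size_on_gpu(block_carries_are_empty: bool) -> u64 {
    let full_prop_mem = if block_carries_are_empty {
        0
    } else {
        full_propagate_size_on_gpu()
    };
    full_prop_mem + op_size_on_gpu()
}

fn main() {
    assert!(total_size_on_gpu(false) > total_size_on_gpu(true));
}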

View File

@@ -6,7 +6,9 @@ use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::{
CudaIntegerRadixCiphertext, CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext,
};
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::server_key::{
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
};
use crate::integer::gpu::{
cuda_backend_get_full_propagate_assign_size_on_gpu,
cuda_backend_get_propagate_single_carry_assign_size_on_gpu,
@@ -179,6 +181,10 @@ impl CudaServerKey {
where
T: CudaIntegerRadixCiphertext,
{
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let full_prop_mem = if ct.block_carries_are_empty() {
0
} else {
@@ -189,8 +195,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -206,8 +212,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
@@ -228,8 +234,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
num_blocks,
@@ -247,8 +253,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
num_blocks,

View File

@@ -3,7 +3,7 @@ use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
use crate::integer::gpu::{
cuda_backend_get_full_propagate_assign_size_on_gpu, cuda_backend_get_scalar_bitop_size_on_gpu,
cuda_backend_unchecked_scalar_bitop_assign, BitOpType, CudaServerKey, PBSType,
@@ -29,6 +29,10 @@ impl CudaServerKey {
.collect::<Vec<_>>();
let clear_blocks = unsafe { CudaVec::from_cpu_async(&h_clear_blocks, streams, 0) };
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -38,19 +42,15 @@ impl CudaServerKey {
&clear_blocks,
&h_clear_blocks,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
op,
@@ -67,19 +67,15 @@ impl CudaServerKey {
&clear_blocks,
&h_clear_blocks,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
op,
@@ -241,6 +237,9 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let full_prop_mem = if ct.block_carries_are_empty() {
0
@@ -252,8 +251,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -269,8 +268,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
@@ -291,14 +290,10 @@ impl CudaServerKey {
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
op,
@@ -314,14 +309,10 @@ impl CudaServerKey {
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
op,
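The scalar paths first decompose the clear operand into radix blocks of `message_modulus` bits on the host (the `h_clear_blocks` above) before uploading them to the device. A sketch of that decomposition, using 2-bit blocks (message modulus 4) as an example choice:

// Sketch: decompose a clear scalar into radix blocks of message-modulus
// bits, least significant block first, as the scalar paths above do
// before staging the blocks on the GPU.
fn decompose(mut scalar: u64, bits_per_block: u32, num_blocks: usize) -> Vec<u64> {
    let mask = (1u64 << bits_per_block) - 1;
    let mut blocks = Vec::with_capacity(num_blocks);
    for _ in 0..num_blocks {
        blocks.push(scalar & mask);
        scalar >>= bits_per_block;
    }
    blocks
}

fn main() {
    // 0b1110_0101 split into 2-bit blocks, least significant first.
    assert_eq!(decompose(0b1110_0101, 2, 4), vec![0b01, 0b01, 0b10, 0b11]);
}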

View File

@@ -6,7 +6,9 @@ use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::info::CudaRadixCiphertextInfo;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaRadixCiphertext};
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::server_key::{
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
};
use crate::integer::gpu::{
cuda_backend_unchecked_are_all_comparisons_block_true,
cuda_backend_unchecked_is_at_least_one_comparisons_block_true,
@@ -167,6 +169,9 @@ impl CudaServerKey {
let mut result =
CudaBooleanBlock::from_cuda_radix_ciphertext(CudaRadixCiphertext::new(block, ct_info));
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -178,19 +183,15 @@ impl CudaServerKey {
&d_scalar_blocks,
&scalar_blocks,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
scalar_blocks.len() as u32,
@@ -209,19 +210,15 @@ impl CudaServerKey {
&d_scalar_blocks,
&scalar_blocks,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
scalar_blocks.len() as u32,
@@ -320,6 +317,9 @@ impl CudaServerKey {
unsafe { CudaVec::from_cpu_async(&scalar_blocks, streams, 0) };
let mut result = ct.duplicate(streams);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -331,19 +331,15 @@ impl CudaServerKey {
&d_scalar_blocks,
&scalar_blocks,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
scalar_blocks.len() as u32,
@@ -362,19 +358,15 @@ impl CudaServerKey {
&d_scalar_blocks,
&scalar_blocks,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
scalar_blocks.len() as u32,
@@ -400,6 +392,10 @@ impl CudaServerKey {
{
let ct_res: T = self.create_trivial_radix(0, 1, streams);
let mut boolean_res = CudaBooleanBlock::from_cuda_radix_ciphertext(ct_res.into_inner());
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -408,19 +404,15 @@ impl CudaServerKey {
boolean_res.as_mut().as_mut(),
ct.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
@@ -434,19 +426,15 @@ impl CudaServerKey {
boolean_res.as_mut().as_mut(),
ct.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
@@ -469,6 +457,10 @@ impl CudaServerKey {
{
let ct_res: T = self.create_trivial_radix(0, 1, streams);
let mut boolean_res = CudaBooleanBlock::from_cuda_radix_ciphertext(ct_res.into_inner());
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -477,19 +469,15 @@ impl CudaServerKey {
boolean_res.as_mut().as_mut(),
ct.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
@@ -503,19 +491,15 @@ impl CudaServerKey {
boolean_res.as_mut().as_mut(),
ct.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,

View File

@@ -4,7 +4,7 @@ use crate::integer::block_decomposition::DecomposableInto;
use crate::integer::gpu::ciphertext::{
CudaIntegerRadixCiphertext, CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext,
};
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
use crate::integer::gpu::{
cuda_backend_get_full_propagate_assign_size_on_gpu,
cuda_backend_get_scalar_div_rem_size_on_gpu, cuda_backend_get_scalar_div_size_on_gpu,
@@ -85,6 +85,9 @@ impl CudaServerKey {
);
let mut quotient = numerator.duplicate(streams);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -93,15 +96,15 @@ impl CudaServerKey {
streams,
quotient.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
&d_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
@@ -114,15 +117,15 @@ impl CudaServerKey {
streams,
quotient.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
&d_multibit_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
@@ -220,6 +223,9 @@ impl CudaServerKey {
numerator.as_ref().d_blocks.lwe_ciphertext_count().0,
streams,
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -229,15 +235,15 @@ impl CudaServerKey {
quotient.as_mut(),
remainder.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
&d_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
@@ -251,15 +257,15 @@ impl CudaServerKey {
quotient.as_mut(),
remainder.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
&d_multibit_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
@@ -420,6 +426,9 @@ impl CudaServerKey {
);
let mut quotient: CudaSignedRadixCiphertext = numerator.duplicate(streams);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -428,15 +437,15 @@ impl CudaServerKey {
streams,
quotient.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
&d_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
@@ -449,15 +458,15 @@ impl CudaServerKey {
streams,
quotient.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
&d_multibit_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
@@ -555,6 +564,9 @@ impl CudaServerKey {
numerator.as_ref().d_blocks.lwe_ciphertext_count().0,
streams,
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -564,15 +576,15 @@ impl CudaServerKey {
quotient.as_mut(),
remainder.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
&d_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
@@ -586,15 +598,15 @@ impl CudaServerKey {
quotient.as_mut(),
remainder.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
&d_multibit_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
@@ -758,6 +770,10 @@ encrypted bits: {numerator_bits}, scalar bits: {}
Scalar::BITS
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let full_prop_mem = if numerator.block_carries_are_empty() {
0
} else {
@@ -768,8 +784,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -785,8 +801,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
@@ -808,8 +824,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
@@ -826,8 +842,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
@@ -864,6 +880,10 @@ encrypted bits: {numerator_bits}, scalar bits: {}
Scalar::BITS
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_scalar_div_rem_size_on_gpu(
streams,
@@ -873,8 +893,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
@@ -891,8 +911,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
@@ -940,6 +960,9 @@ encrypted bits: {numerator_bits}, scalar bits: {}
"The scalar divisor type must have a number of bits that is\
>= to the number of bits encrypted in the ciphertext"
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_signed_scalar_div_size_on_gpu(
@@ -952,8 +975,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
d_bsk.input_lwe_dimension,
d_bsk.decomp_base_log,
d_bsk.decomp_level_count,
self.key_switching_key.decomposition_base_log(),
self.key_switching_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
LweBskGroupingFactor(0),
num_blocks,
PBSType::Classical,
@@ -970,8 +993,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
d_multibit_bsk.input_lwe_dimension,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.decomp_level_count,
self.key_switching_key.decomposition_base_log(),
self.key_switching_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
d_multibit_bsk.grouping_factor,
num_blocks,
PBSType::MultiBit,
@@ -1001,6 +1024,9 @@ encrypted bits: {numerator_bits}, scalar bits: {}
"The scalar divisor type must have a number of bits that is\
>= to the number of bits encrypted in the ciphertext"
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -1012,8 +1038,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
@@ -1031,8 +1057,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,

View File
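
A note on the recurring change in the hunks above: call sites stop reading decomposition parameters through self.key_switching_key directly and first destructure the new CudaDynamicKeyswitchingKey enum, panicking on variants the GPU path does not support. A minimal self-contained sketch of that pattern, with StdKsk/Ksk32 as stand-ins for the real CUDA key types (the actual field names in tfhe-rs differ):

// Stand-in key types; the real tfhe-rs keys hold CUDA device vectors and
// richer metadata.
struct StdKsk { base_log: usize, level_count: usize }
#[allow(dead_code)]
struct Ksk32 { base_log: usize, level_count: usize }

// Sketch of the enum these hunks match on: one variant per keyswitch width.
#[allow(dead_code)]
enum CudaDynamicKeyswitchingKey {
    Standard(StdKsk),   // classic 64-bit -> 64-bit keyswitch
    KeySwitch32(Ksk32), // 64-bit -> 32-bit keyswitch (the new KS32 path)
}

fn decomposition_params(key: &CudaDynamicKeyswitchingKey) -> (usize, usize) {
    // `let ... else` either yields an irrefutable binding or diverges, so
    // everything after it can use `computing_ks_key` unconditionally.
    let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = key else {
        panic!("Only the standard atomic pattern is supported on GPU")
    };
    (computing_ks_key.base_log, computing_ks_key.level_count)
}

fn main() {
    let key = CudaDynamicKeyswitchingKey::Standard(StdKsk { base_log: 3, level_count: 5 });
    assert_eq!(decomposition_params(&key), (3, 5));
}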

@@ -2,7 +2,9 @@ use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::server_key::{
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
};
use crate::integer::gpu::{
cuda_backend_get_full_propagate_assign_size_on_gpu, cuda_backend_get_scalar_mul_size_on_gpu,
cuda_backend_unchecked_scalar_mul, PBSType,
@@ -107,6 +109,9 @@ impl CudaServerKey {
if decomposed_scalar.is_empty() {
return;
}
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -117,18 +122,16 @@ impl CudaServerKey {
decomposed_scalar.as_slice(),
has_at_least_one_set.as_slice(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
d_bsk.decomp_base_log,
d_bsk.decomp_level_count,
self.key_switching_key.decomposition_base_log(),
self.key_switching_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
decomposed_scalar.len() as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
@@ -142,18 +145,16 @@ impl CudaServerKey {
decomposed_scalar.as_slice(),
has_at_least_one_set.as_slice(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.decomp_level_count,
self.key_switching_key.decomposition_base_log(),
self.key_switching_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
decomposed_scalar.len() as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
@@ -246,6 +247,9 @@ impl CudaServerKey {
// than multiplying
return self.get_scalar_left_shift_size_on_gpu(ct, streams);
}
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let full_prop_mem = if ct.block_carries_are_empty() {
0
@@ -257,8 +261,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -274,8 +278,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
@@ -302,13 +306,11 @@ impl CudaServerKey {
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
d_bsk.decomp_base_log,
d_bsk.decomp_level_count,
self.key_switching_key.decomposition_base_log(),
self.key_switching_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
@@ -322,13 +324,11 @@ impl CudaServerKey {
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.decomp_level_count,
self.key_switching_key.decomposition_base_log(),
self.key_switching_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,

View File
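
The .input_key_lwe_size().to_lwe_dimension() / .output_key_lwe_size().to_lwe_dimension() chains repeated above rely on the core_crypto convention that an LWE size counts the mask coefficients plus the body, i.e. LweSize = LweDimension + 1. A self-contained sketch of that conversion (types re-declared here for illustration; the real ones live in crate::core_crypto):

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct LweSize(usize);

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct LweDimension(usize);

impl LweSize {
    // An LWE ciphertext stores `dimension` mask coefficients plus one body
    // element, so the dimension is the size minus one.
    fn to_lwe_dimension(self) -> LweDimension {
        LweDimension(self.0 - 1)
    }
}

fn main() {
    // e.g. a keyswitch output of size 742 targets a 741-dimensional LWE key
    assert_eq!(LweSize(742).to_lwe_dimension(), LweDimension(741));
}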

@@ -1,7 +1,7 @@
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::{CastFrom, LweBskGroupingFactor};
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
use crate::integer::gpu::{
cuda_backend_get_full_propagate_assign_size_on_gpu,
cuda_backend_get_scalar_rotate_left_size_on_gpu,
@@ -38,6 +38,10 @@ impl CudaServerKey {
u32: CastFrom<Scalar>,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -46,19 +50,15 @@ impl CudaServerKey {
ct.as_mut(),
u32::cast_from(n),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -73,19 +73,15 @@ impl CudaServerKey {
ct.as_mut(),
u32::cast_from(n),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -125,6 +121,10 @@ impl CudaServerKey {
u32: CastFrom<Scalar>,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -133,19 +133,15 @@ impl CudaServerKey {
ct.as_mut(),
u32::cast_from(n),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -160,19 +156,15 @@ impl CudaServerKey {
ct.as_mut(),
u32::cast_from(n),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -237,6 +229,9 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let full_prop_mem = if ct.block_carries_are_empty() {
0
@@ -248,8 +243,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -265,8 +260,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
@@ -286,14 +281,10 @@ impl CudaServerKey {
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -309,14 +300,10 @@ impl CudaServerKey {
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -334,6 +321,9 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let full_prop_mem = if ct.block_carries_are_empty() {
0
@@ -345,8 +335,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -362,8 +352,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
@@ -382,14 +372,10 @@ impl CudaServerKey {
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -403,14 +389,10 @@ impl CudaServerKey {
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,

View File

@@ -1,7 +1,7 @@
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::{CastFrom, LweBskGroupingFactor};
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
use crate::integer::gpu::{
cuda_backend_get_full_propagate_assign_size_on_gpu,
cuda_backend_get_scalar_arithmetic_right_shift_size_on_gpu,
@@ -76,6 +76,9 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -85,19 +88,15 @@ impl CudaServerKey {
ct.as_mut(),
u32::cast_from(shift),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -112,19 +111,15 @@ impl CudaServerKey {
ct.as_mut(),
u32::cast_from(shift),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -200,6 +195,9 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
if T::IS_SIGNED {
@@ -210,19 +208,15 @@ impl CudaServerKey {
ct.as_mut(),
u32::cast_from(shift),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
@@ -236,19 +230,15 @@ impl CudaServerKey {
ct.as_mut(),
u32::cast_from(shift),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
@@ -265,19 +255,15 @@ impl CudaServerKey {
ct.as_mut(),
u32::cast_from(shift),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -292,19 +278,15 @@ impl CudaServerKey {
ct.as_mut(),
u32::cast_from(shift),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -451,10 +433,16 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let full_prop_mem = if ct.block_carries_are_empty() {
0
} else {
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key
else {
panic!("Only the standard atomic pattern is supported on GPU")
};
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_get_full_propagate_assign_size_on_gpu(
@@ -462,8 +450,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -479,8 +467,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
@@ -492,6 +480,7 @@ impl CudaServerKey {
}
}
};
let scalar_shift_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_scalar_left_shift_size_on_gpu(
streams,
@@ -499,14 +488,10 @@ impl CudaServerKey {
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -521,14 +506,10 @@ impl CudaServerKey {
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -546,6 +527,9 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let full_prop_mem = if ct.block_carries_are_empty() {
0
@@ -557,8 +541,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -574,8 +558,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
@@ -596,14 +580,10 @@ impl CudaServerKey {
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -619,14 +599,10 @@ impl CudaServerKey {
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -645,14 +621,10 @@ impl CudaServerKey {
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -668,14 +640,10 @@ impl CudaServerKey {
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,

View File
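
A design note on the call sites above, sketched as hypothetical code (this helper is not part of tfhe-rs): every backend call threads the same four keyswitch values (input dimension, output dimension, base log, level count), so bundling them once per Standard destructuring would shrink the argument lists and keep base_log/level_count from being swapped by accident.

// Hypothetical bundle; names chosen here for illustration only.
#[derive(Clone, Copy, Debug)]
struct KsGeometry {
    input_lwe_dimension: usize,
    output_lwe_dimension: usize,
    decomposition_base_log: usize,
    decomposition_level_count: usize,
}

// Stand-in for the standard CUDA keyswitching key; sizes are dimension + 1.
struct StdKsk {
    input_lwe_size: usize,
    output_lwe_size: usize,
    base_log: usize,
    level_count: usize,
}

impl StdKsk {
    fn geometry(&self) -> KsGeometry {
        KsGeometry {
            input_lwe_dimension: self.input_lwe_size - 1,
            output_lwe_dimension: self.output_lwe_size - 1,
            decomposition_base_log: self.base_log,
            decomposition_level_count: self.level_count,
        }
    }
}

fn main() {
    let ksk = StdKsk { input_lwe_size: 2049, output_lwe_size: 742, base_log: 3, level_count: 5 };
    let g = ksk.geometry();
    assert_eq!((g.input_lwe_dimension, g.output_lwe_dimension), (2048, 741));
}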

@@ -1,7 +1,7 @@
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
use crate::integer::gpu::{
cuda_backend_get_full_propagate_assign_size_on_gpu, cuda_backend_get_left_shift_size_on_gpu,
cuda_backend_get_right_shift_size_on_gpu, cuda_backend_unchecked_left_shift_assign,
@@ -19,6 +19,9 @@ impl CudaServerKey {
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let is_signed = T::IS_SIGNED;
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -28,19 +31,15 @@ impl CudaServerKey {
ct.as_mut(),
shift.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -56,19 +55,15 @@ impl CudaServerKey {
ct.as_mut(),
shift.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -106,6 +101,9 @@ impl CudaServerKey {
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let is_signed = T::IS_SIGNED;
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -115,19 +113,15 @@ impl CudaServerKey {
ct.as_mut(),
shift.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -143,19 +137,15 @@ impl CudaServerKey {
ct.as_mut(),
shift.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -435,6 +425,10 @@ impl CudaServerKey {
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let full_prop_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_get_full_propagate_assign_size_on_gpu(
@@ -442,8 +436,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -459,8 +453,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
@@ -482,6 +476,9 @@ impl CudaServerKey {
};
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let shift_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_left_shift_size_on_gpu(
@@ -490,14 +487,10 @@ impl CudaServerKey {
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -513,14 +506,10 @@ impl CudaServerKey {
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -548,6 +537,10 @@ impl CudaServerKey {
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let full_prop_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_get_full_propagate_assign_size_on_gpu(
@@ -555,8 +548,8 @@ impl CudaServerKey {
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
self.message_modulus,
@@ -572,8 +565,8 @@ impl CudaServerKey {
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
self.message_modulus,
@@ -595,6 +588,9 @@ impl CudaServerKey {
};
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let shift_mem = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_right_shift_size_on_gpu(
@@ -603,14 +599,10 @@ impl CudaServerKey {
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
@@ -626,14 +618,10 @@ impl CudaServerKey {
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,

View File

@@ -4,7 +4,7 @@ use crate::integer::gpu::ciphertext::{
CudaIntegerRadixCiphertext, CudaRadixCiphertext, CudaSignedRadixCiphertext,
CudaUnsignedRadixCiphertext,
};
use crate::integer::gpu::server_key::CudaServerKey;
use crate::integer::gpu::server_key::{CudaDynamicKeyswitchingKey, CudaServerKey};
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::{
@@ -296,6 +296,9 @@ impl CudaServerKey {
let aux_block: CudaUnsignedRadixCiphertext = self.create_trivial_zero_radix(1, stream);
let in_carry_dvec =
INPUT_BORROW.map_or_else(|| aux_block.as_ref(), |block| block.as_ref().as_ref());
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -307,12 +310,12 @@ impl CudaServerKey {
overflow_block.as_mut(),
in_carry_dvec,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
ciphertext.info.blocks.first().unwrap().message_modulus,
@@ -332,12 +335,12 @@ impl CudaServerKey {
overflow_block.as_mut(),
in_carry_dvec,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
ciphertext.info.blocks.first().unwrap().message_modulus,
@@ -374,6 +377,9 @@ impl CudaServerKey {
let aux_block: T = self.create_trivial_zero_radix(1, streams);
let in_carry: &CudaRadixCiphertext =
input_carry.map_or_else(|| aux_block.as_ref(), |block| block.0.as_ref());
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -385,12 +391,12 @@ impl CudaServerKey {
carry_out.as_mut(),
in_carry,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
num_blocks,
@@ -411,12 +417,12 @@ impl CudaServerKey {
carry_out.as_mut(),
in_carry,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
num_blocks,

View File

@@ -18,7 +18,9 @@ use crate::core_crypto::prelude::*;
use crate::integer::gpu::ciphertext::info::CudaBlockInfo;
use crate::integer::gpu::ciphertext::CudaRadixCiphertext;
use crate::integer::gpu::server_key::radix::{CudaNoiseSquashingKey, CudaRadixCiphertextInfo};
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::server_key::{
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
};
use crate::integer::gpu::{
cuda_centered_modulus_switch_64, unchecked_small_scalar_mul_integer_async, CudaStreams,
};
@@ -417,10 +419,14 @@ impl AllocateLweKeyswitchResult for CudaServerKey {
&self,
side_resources: &mut Self::SideResources,
) -> Self::Output {
let output_lwe_dimension = self
.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension();
let output_lwe_dimension = match &self.key_switching_key {
CudaDynamicKeyswitchingKey::Standard(std_key) => {
std_key.output_key_lwe_size().to_lwe_dimension()
}
CudaDynamicKeyswitchingKey::KeySwitch32(ks32_key) => {
ks32_key.output_key_lwe_size().to_lwe_dimension()
}
};
let lwe_ciphertext_count = LweCiphertextCount(1);
let ciphertext_modulus = self.ciphertext_modulus;
@@ -444,12 +450,39 @@ impl LweKeyswitch<CudaDynLwe, CudaDynLwe> for CudaServerKey {
side_resources: &mut Self::SideResources,
) {
match (input, output) {
(CudaDynLwe::U64(input_cuda_lwe), CudaDynLwe::U64(output_cuda_lwe)) => {
(CudaDynLwe::U64(input_cuda_lwe), CudaDynLwe::U32(output_cuda_lwe)) => {
let CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) =
&self.key_switching_key
else {
panic!("Expecting 32b KSK in Cuda noise simulation tests when LWE is 32b");
};
let input_indexes = CudaVec::new(1, &side_resources.streams, 0);
let output_indexes = CudaVec::new(1, &side_resources.streams, 0);
cuda_keyswitch_lwe_ciphertext(
&self.key_switching_key,
computing_ks_key,
input_cuda_lwe,
output_cuda_lwe,
&input_indexes,
&output_indexes,
false,
&side_resources.streams,
false,
);
}
(CudaDynLwe::U64(input_cuda_lwe), CudaDynLwe::U64(output_cuda_lwe)) => {
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) =
&self.key_switching_key
else {
panic!("Expecting 64b KSK in Cuda noise simulation tests when LWE is 64b");
};
let input_indexes = CudaVec::new(1, &side_resources.streams, 0);
let output_indexes = CudaVec::new(1, &side_resources.streams, 0);
cuda_keyswitch_lwe_ciphertext(
computing_ks_key,
input_cuda_lwe,
output_cuda_lwe,
&input_indexes,

View File
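
The LweKeyswitch impl above is where the two key widths actually diverge: a 64-bit input keyswitched under Standard stays on the 64-bit torus, while KeySwitch32 produces a 32-bit output. A hedged sketch of that dispatch shape, with DynLwe/DynKsk as host-side stand-ins for CudaDynLwe and CudaDynamicKeyswitchingKey (the real entry points are the cuda_keyswitch_lwe_ciphertext_vector_64_64/_64_32 functions this commit declares in the C header):

// Payloads are host vectors here instead of CUDA device buffers.
enum DynLwe { U32(Vec<u32>), U64(Vec<u64>) }
enum DynKsk { Standard(Vec<u64>), KeySwitch32(Vec<u32>) }

fn keyswitch(key: &DynKsk, _input: &[u64], output: &mut DynLwe) {
    match (key, output) {
        // 64-bit KSK keeps the ciphertext on the 64-bit torus
        (DynKsk::Standard(_ksk), DynLwe::U64(_out)) => {
            // would invoke the 64 -> 64 backend entry point here
        }
        // 32-bit KSK switches the ciphertext down to the 32-bit torus
        (DynKsk::KeySwitch32(_ksk), DynLwe::U32(_out)) => {
            // would invoke the 64 -> 32 backend entry point here
        }
        // mirrors the panics in the impl above
        _ => panic!("key variant does not match the output torus width"),
    }
}

fn main() {
    let mut out64 = DynLwe::U64(vec![0u64; 8]);
    keyswitch(&DynKsk::Standard(vec![0u64; 16]), &[0u64; 8], &mut out64);

    let mut out32 = DynLwe::U32(vec![0u32; 8]);
    keyswitch(&DynKsk::KeySwitch32(vec![0u32; 16]), &[0u64; 8], &mut out32);
}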

@@ -2,7 +2,9 @@ use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::server_key::{
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
};
use crate::integer::gpu::{
cuda_backend_unchecked_all_eq_slices, cuda_backend_unchecked_contains_sub_slice, PBSType,
};
@@ -56,6 +58,10 @@ impl CudaServerKey {
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -65,19 +71,15 @@ impl CudaServerKey {
lhs,
rhs,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
@@ -92,19 +94,15 @@ impl CudaServerKey {
lhs,
rhs,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
@@ -269,6 +267,10 @@ impl CudaServerKey {
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -278,19 +280,15 @@ impl CudaServerKey {
lhs,
rhs,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
@@ -305,19 +303,15 @@ impl CudaServerKey {
lhs,
rhs,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,

View File

@@ -3,7 +3,9 @@ use crate::core_crypto::prelude::{LweBskGroupingFactor, UnsignedInteger};
use crate::integer::block_decomposition::DecomposableInto;
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::server_key::{
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
};
use crate::integer::gpu::{
cuda_backend_get_unchecked_match_value_or_size_on_gpu,
cuda_backend_get_unchecked_match_value_size_on_gpu, cuda_backend_unchecked_contains,
@@ -52,6 +54,10 @@ impl CudaServerKey {
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams),
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -64,17 +70,13 @@ impl CudaServerKey {
self.message_modulus,
self.carry_modulus,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
@@ -92,17 +94,13 @@ impl CudaServerKey {
self.message_modulus,
self.carry_modulus,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
@@ -130,6 +128,10 @@ impl CudaServerKey {
return 0;
}
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_get_unchecked_match_value_size_on_gpu(
@@ -138,14 +140,10 @@ impl CudaServerKey {
matches,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
@@ -162,14 +160,10 @@ impl CudaServerKey {
matches,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
@@ -296,6 +290,10 @@ impl CudaServerKey {
let mut result: CudaUnsignedRadixCiphertext =
self.create_trivial_zero_radix(final_num_blocks, streams);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -308,17 +306,13 @@ impl CudaServerKey {
self.message_modulus,
self.carry_modulus,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
@@ -336,17 +330,13 @@ impl CudaServerKey {
self.message_modulus,
self.carry_modulus,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
@@ -375,6 +365,10 @@ impl CudaServerKey {
return 0;
}
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_get_unchecked_match_value_or_size_on_gpu(
@@ -384,14 +378,10 @@ impl CudaServerKey {
or_value,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
@@ -409,14 +399,10 @@ impl CudaServerKey {
or_value,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
@@ -519,6 +505,9 @@ impl CudaServerKey {
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams)
.into_inner(),
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -529,19 +518,15 @@ impl CudaServerKey {
cts,
value.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
@@ -556,19 +541,15 @@ impl CudaServerKey {
cts,
value.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
@@ -675,6 +656,9 @@ impl CudaServerKey {
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams)
.into_inner(),
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -685,19 +669,15 @@ impl CudaServerKey {
cts,
clear,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
@@ -712,19 +692,15 @@ impl CudaServerKey {
cts,
clear,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
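
The mechanical change in every hunk above swaps the `self.key_switching_key` accessor chains for the same accessors on the extracted `computing_ks_key`. A compiling sketch of the four parameters being read, with stand-in types (only the accessor names are taken from the diff; in tfhe-rs the real types are `LweSize`, `LweDimension`, and the CUDA keyswitching key):

#[derive(Clone, Copy)]
struct LweSize(usize);

#[derive(Clone, Copy, Debug)]
struct LweDimension(usize);

impl LweSize {
    // An LWE size counts the mask coefficients plus the body, so the
    // dimension is one less.
    fn to_lwe_dimension(self) -> LweDimension {
        LweDimension(self.0 - 1)
    }
}

struct Ksk {
    input_key_lwe_size: LweSize,
    output_key_lwe_size: LweSize,
    level_count: u32,
    base_log: u32,
}

impl Ksk {
    fn input_key_lwe_size(&self) -> LweSize { self.input_key_lwe_size }
    fn output_key_lwe_size(&self) -> LweSize { self.output_key_lwe_size }
    fn decomposition_level_count(&self) -> u32 { self.level_count }
    fn decomposition_base_log(&self) -> u32 { self.base_log }
}

// The four values every backend call receives, in the order seen above.
fn ks_params(ksk: &Ksk) -> (LweDimension, LweDimension, u32, u32) {
    (
        ksk.input_key_lwe_size().to_lwe_dimension(),
        ksk.output_key_lwe_size().to_lwe_dimension(),
        ksk.decomposition_level_count(),
        ksk.decomposition_base_log(),
    )
}
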
@@ -822,6 +798,9 @@ impl CudaServerKey {
let ct_res: CudaUnsignedRadixCiphertext = self.create_trivial_radix(0, 1, streams);
let mut boolean_res = CudaBooleanBlock::from_cuda_radix_ciphertext(ct_res.into_inner());
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -832,19 +811,15 @@ impl CudaServerKey {
ct.as_ref(),
clears,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
@@ -859,19 +834,15 @@ impl CudaServerKey {
ct.as_ref(),
clears,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
@@ -982,6 +953,9 @@ impl CudaServerKey {
let trivial_bool =
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -993,19 +967,15 @@ impl CudaServerKey {
ct.as_ref(),
clears,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
@@ -1021,19 +991,15 @@ impl CudaServerKey {
ct.as_ref(),
clears,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
@@ -1154,6 +1120,9 @@ impl CudaServerKey {
let trivial_bool =
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -1165,19 +1134,15 @@ impl CudaServerKey {
ct.as_ref(),
clears,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
@@ -1193,19 +1158,15 @@ impl CudaServerKey {
ct.as_ref(),
clears,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
@@ -1313,6 +1274,9 @@ impl CudaServerKey {
let trivial_bool: CudaUnsignedRadixCiphertext = self.create_trivial_zero_radix(1, streams);
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
@@ -1324,19 +1288,15 @@ impl CudaServerKey {
cts,
value.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
@@ -1352,19 +1312,15 @@ impl CudaServerKey {
cts,
value.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
@@ -1501,6 +1457,10 @@ impl CudaServerKey {
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -1511,19 +1471,15 @@ impl CudaServerKey {
cts,
clear,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
@@ -1539,19 +1495,15 @@ impl CudaServerKey {
cts,
clear,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
@@ -1677,6 +1629,10 @@ impl CudaServerKey {
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -1687,19 +1643,15 @@ impl CudaServerKey {
cts,
clear,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
@@ -1715,19 +1667,15 @@ impl CudaServerKey {
cts,
clear,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
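
Around each keyswitch-key change, the pre-existing dispatch on the bootstrapping key is left untouched: the classic key routes to `PBSType::Classical`, the multi-bit key to `PBSType::MultiBit`. A simplified sketch of that dispatch (illustrative stand-in types only):

enum PbsType {
    Classical,
    MultiBit,
}

enum Bsk {
    Classic,
    MultiBit { grouping_factor: u32 },
}

fn pbs_type(bsk: &Bsk) -> PbsType {
    match bsk {
        Bsk::Classic => PbsType::Classical,
        // The multi-bit path additionally forwards its grouping factor,
        // where the classic path passes LweBskGroupingFactor(0).
        Bsk::MultiBit { .. } => PbsType::MultiBit,
    }
}
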
@@ -1851,6 +1799,10 @@ impl CudaServerKey {
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -1861,19 +1813,15 @@ impl CudaServerKey {
cts,
value.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
@@ -1889,19 +1837,15 @@ impl CudaServerKey {
cts,
value.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,