mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
fix(gpu): erc20 classical ks32
This commit is contained in:
@@ -1044,7 +1044,7 @@ template <typename Torus, typename KSTorus> struct unsigned_int_div_rem_memory {
|
||||
};
|
||||
|
||||
int_radix_lut<Torus, KSTorus> *luts[2] = {message_extract_lut_1,
|
||||
message_extract_lut_2};
|
||||
message_extract_lut_2};
|
||||
auto active_streams =
|
||||
streams.active_gpu_subset(num_blocks, params.pbs_type);
|
||||
for (int j = 0; j < 2; j++) {
|
||||
@@ -1134,7 +1134,8 @@ template <typename Torus, typename KSTorus> struct unsigned_int_div_rem_memory {
|
||||
zero_out_if_overflow_happened[1]->broadcast_lut(active_streams);
|
||||
|
||||
// merge_overflow_flags_luts
|
||||
merge_overflow_flags_luts = new int_radix_lut<Torus, KSTorus> *[num_bits_in_message];
|
||||
merge_overflow_flags_luts =
|
||||
new int_radix_lut<Torus, KSTorus> *[num_bits_in_message];
|
||||
auto active_gpu_count_for_bits =
|
||||
streams.active_gpu_subset(1, params.pbs_type);
|
||||
for (int i = 0; i < num_bits_in_message; i++) {
|
||||
|
||||
@@ -274,6 +274,16 @@ uint64_t scratch_cuda_comparison_64(
|
||||
bool is_signed, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type);
|
||||
|
||||
uint64_t scratch_cuda_comparison_64_ks32(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
|
||||
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
|
||||
uint32_t lwe_ciphertext_count, uint32_t message_modulus,
|
||||
uint32_t carry_modulus, PBS_TYPE pbs_type, COMPARISON_TYPE op_type,
|
||||
bool is_signed, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type);
|
||||
|
||||
void cuda_comparison_ciphertext_64(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_1,
|
||||
@@ -281,6 +291,12 @@ void cuda_comparison_ciphertext_64(CudaStreamsFFI streams,
|
||||
int8_t *mem_ptr, void *const *bsks,
|
||||
void *const *ksks);
|
||||
|
||||
void cuda_comparison_ciphertext_64_ks32(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_1,
|
||||
CudaRadixCiphertextFFI const *lwe_array_2, int8_t *mem_ptr,
|
||||
void *const *bsks, void *const *ksks);
|
||||
|
||||
void cuda_scalar_comparison_ciphertext_64(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_in, void const *scalar_blocks,
|
||||
@@ -323,6 +339,20 @@ void cuda_boolean_bitnot_ciphertext_64(CudaStreamsFFI streams,
|
||||
int8_t *mem_ptr, void *const *bsks,
|
||||
void *const *ksks);
|
||||
|
||||
uint64_t scratch_cuda_boolean_bitnot_64_ks32(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
|
||||
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
|
||||
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
|
||||
uint32_t lwe_ciphertext_count, bool is_unchecked, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type);
|
||||
|
||||
void cuda_boolean_bitnot_ciphertext_64_ks32(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *lwe_array,
|
||||
int8_t *mem_ptr, void *const *bsks,
|
||||
void *const *ksks);
|
||||
|
||||
void cleanup_cuda_boolean_bitnot(CudaStreamsFFI streams, int8_t **mem_ptr_void);
|
||||
|
||||
void cuda_bitnot_ciphertext_64(CudaStreamsFFI streams,
|
||||
@@ -340,6 +370,15 @@ uint64_t scratch_cuda_bitop_64(
|
||||
uint32_t carry_modulus, PBS_TYPE pbs_type, BITOP_TYPE op_type,
|
||||
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type);
|
||||
|
||||
uint64_t scratch_cuda_bitop_64_ks32(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
|
||||
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
|
||||
uint32_t lwe_ciphertext_count, uint32_t message_modulus,
|
||||
uint32_t carry_modulus, PBS_TYPE pbs_type, BITOP_TYPE op_type,
|
||||
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type);
|
||||
|
||||
void cuda_scalar_bitop_ciphertext_64(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_input, void const *clear_blocks,
|
||||
@@ -353,6 +392,13 @@ void cuda_bitop_ciphertext_64(CudaStreamsFFI streams,
|
||||
int8_t *mem_ptr, void *const *bsks,
|
||||
void *const *ksks);
|
||||
|
||||
void cuda_bitop_ciphertext_64_ks32(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_1,
|
||||
CudaRadixCiphertextFFI const *lwe_array_2,
|
||||
int8_t *mem_ptr, void *const *bsks,
|
||||
void *const *ksks);
|
||||
|
||||
void cleanup_cuda_integer_bitop(CudaStreamsFFI streams, int8_t **mem_ptr_void);
|
||||
|
||||
uint64_t scratch_cuda_cmux_64(CudaStreamsFFI streams, int8_t **mem_ptr,
|
||||
@@ -366,6 +412,15 @@ uint64_t scratch_cuda_cmux_64(CudaStreamsFFI streams, int8_t **mem_ptr,
|
||||
PBS_TYPE pbs_type, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type);
|
||||
|
||||
uint64_t scratch_cuda_cmux_64_ks32(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
|
||||
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
|
||||
uint32_t lwe_ciphertext_count, uint32_t message_modulus,
|
||||
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type);
|
||||
|
||||
void cuda_cmux_ciphertext_64(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_condition,
|
||||
@@ -374,6 +429,14 @@ void cuda_cmux_ciphertext_64(CudaStreamsFFI streams,
|
||||
int8_t *mem_ptr, void *const *bsks,
|
||||
void *const *ksks);
|
||||
|
||||
void cuda_cmux_ciphertext_64_ks32(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_condition,
|
||||
CudaRadixCiphertextFFI const *lwe_array_true,
|
||||
CudaRadixCiphertextFFI const *lwe_array_false,
|
||||
int8_t *mem_ptr, void *const *bsks,
|
||||
void *const *ksks);
|
||||
|
||||
void cleanup_cuda_cmux(CudaStreamsFFI streams, int8_t **mem_ptr_void);
|
||||
|
||||
uint64_t scratch_cuda_scalar_rotate_64(
|
||||
@@ -452,6 +515,15 @@ uint64_t scratch_cuda_integer_overflowing_sub_64_inplace(
|
||||
PBS_TYPE pbs_type, uint32_t compute_overflow, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type);
|
||||
|
||||
uint64_t scratch_cuda_integer_overflowing_sub_64_ks32_inplace(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
|
||||
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
|
||||
uint32_t num_blocks, uint32_t message_modulus, uint32_t carry_modulus,
|
||||
PBS_TYPE pbs_type, uint32_t compute_overflow, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type);
|
||||
|
||||
void cuda_integer_overflowing_sub_64_inplace(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lhs_array,
|
||||
const CudaRadixCiphertextFFI *rhs_array,
|
||||
@@ -460,6 +532,14 @@ void cuda_integer_overflowing_sub_64_inplace(
|
||||
void *const *bsks, void *const *ksks, uint32_t compute_overflow,
|
||||
uint32_t uses_input_borrow);
|
||||
|
||||
void cuda_integer_overflowing_sub_64_ks32_inplace(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lhs_array,
|
||||
const CudaRadixCiphertextFFI *rhs_array,
|
||||
CudaRadixCiphertextFFI *overflow_block,
|
||||
const CudaRadixCiphertextFFI *input_borrow, int8_t *mem_ptr,
|
||||
void *const *bsks, void *const *ksks, uint32_t compute_overflow,
|
||||
uint32_t uses_input_borrow);
|
||||
|
||||
void cleanup_cuda_integer_overflowing_sub(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void);
|
||||
|
||||
@@ -844,12 +924,29 @@ uint64_t scratch_cuda_cast_to_unsigned_64(
|
||||
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type);
|
||||
|
||||
uint64_t scratch_cuda_cast_to_unsigned_64_ks32(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
|
||||
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
|
||||
uint32_t num_input_blocks, uint32_t target_num_blocks, bool input_is_signed,
|
||||
bool requires_full_propagate, uint32_t message_modulus,
|
||||
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type);
|
||||
|
||||
void cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *output,
|
||||
CudaRadixCiphertextFFI *input, int8_t *mem_ptr,
|
||||
uint32_t target_num_blocks, bool input_is_signed,
|
||||
void *const *bsks, void *const *ksks);
|
||||
|
||||
void cuda_cast_to_unsigned_64_ks32(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *output,
|
||||
CudaRadixCiphertextFFI *input,
|
||||
int8_t *mem_ptr, uint32_t target_num_blocks,
|
||||
bool input_is_signed, void *const *bsks,
|
||||
void *const *ksks);
|
||||
|
||||
void cleanup_cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void);
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ __host__ void vectorized_aes_256_encrypt_inplace(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *all_states_bitsliced,
|
||||
CudaRadixCiphertextFFI const *round_keys, uint32_t num_aes_inputs,
|
||||
int_aes_encrypt_buffer<Torus, KSTorus> *mem, void *const *bsks,
|
||||
Torus *const *ksks) {
|
||||
KSTorus *const *ksks) {
|
||||
|
||||
constexpr uint32_t BITS_PER_BYTE = 8;
|
||||
constexpr uint32_t STATE_BYTES = 16;
|
||||
@@ -186,7 +186,7 @@ __host__ void host_integer_aes_ctr_256_encrypt(
|
||||
CudaRadixCiphertextFFI const *iv, CudaRadixCiphertextFFI const *round_keys,
|
||||
const Torus *counter_bits_le_all_blocks, uint32_t num_aes_inputs,
|
||||
int_aes_encrypt_buffer<Torus, KSTorus> *mem, void *const *bsks,
|
||||
Torus *const *ksks) {
|
||||
KSTorus *const *ksks) {
|
||||
|
||||
constexpr uint32_t NUM_BITS = 128;
|
||||
|
||||
@@ -245,7 +245,7 @@ __host__ void host_integer_key_expansion_256(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *expanded_keys,
|
||||
CudaRadixCiphertextFFI const *key,
|
||||
int_key_expansion_256_buffer<Torus, KSTorus> *mem, void *const *bsks,
|
||||
Torus *const *ksks) {
|
||||
KSTorus *const *ksks) {
|
||||
|
||||
constexpr uint32_t BITS_PER_WORD = 32;
|
||||
constexpr uint32_t BITS_PER_BYTE = 8;
|
||||
|
||||
@@ -23,10 +23,10 @@ __host__ uint64_t scratch_cuda_integer_abs(
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_abs(CudaStreams streams, CudaRadixCiphertextFFI *ct,
|
||||
void *const *bsks, uint64_t *const *ksks,
|
||||
int_abs_buffer<uint64_t, uint64_t> *mem_ptr,
|
||||
void *const *bsks, KSTorus *const *ksks,
|
||||
int_abs_buffer<Torus, KSTorus> *mem_ptr,
|
||||
bool is_signed) {
|
||||
if (!is_signed)
|
||||
return;
|
||||
@@ -40,7 +40,7 @@ __host__ void host_integer_abs(CudaStreams streams, CudaRadixCiphertextFFI *ct,
|
||||
copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
|
||||
mask, ct);
|
||||
|
||||
host_arithmetic_scalar_shift_inplace<Torus>(
|
||||
host_arithmetic_scalar_shift_inplace<Torus, KSTorus>(
|
||||
streams, mask, num_bits_in_ciphertext - 1,
|
||||
mem_ptr->arithmetic_scalar_shift_mem, bsks, ksks);
|
||||
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), ct, mask, ct,
|
||||
@@ -49,11 +49,12 @@ __host__ void host_integer_abs(CudaStreams streams, CudaRadixCiphertextFFI *ct,
|
||||
|
||||
uint32_t requested_flag = outputFlag::FLAG_NONE;
|
||||
uint32_t uses_carry = 0;
|
||||
host_propagate_single_carry<Torus>(streams, ct, nullptr, nullptr,
|
||||
mem_ptr->scp_mem, bsks, ksks,
|
||||
requested_flag, uses_carry);
|
||||
host_propagate_single_carry<Torus, KSTorus>(streams, ct, nullptr, nullptr,
|
||||
mem_ptr->scp_mem, bsks, ksks,
|
||||
requested_flag, uses_carry);
|
||||
|
||||
host_bitop<Torus>(streams, ct, mask, ct, mem_ptr->bitxor_mem, bsks, ksks);
|
||||
host_bitop<Torus, KSTorus>(streams, ct, mask, ct, mem_ptr->bitxor_mem, bsks,
|
||||
ksks);
|
||||
}
|
||||
|
||||
#endif // TFHE_RS_ABS_CUH
|
||||
|
||||
@@ -63,6 +63,26 @@ uint64_t scratch_cuda_boolean_bitnot_64(
|
||||
lwe_ciphertext_count, is_unchecked, allocate_gpu_memory);
|
||||
}
|
||||
|
||||
uint64_t scratch_cuda_boolean_bitnot_64_ks32(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
|
||||
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
|
||||
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
|
||||
uint32_t lwe_ciphertext_count, bool is_unchecked, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type) {
|
||||
|
||||
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
|
||||
big_lwe_dimension, small_lwe_dimension, ks_level,
|
||||
ks_base_log, pbs_level, pbs_base_log, grouping_factor,
|
||||
message_modulus, carry_modulus, noise_reduction_type);
|
||||
|
||||
return scratch_cuda_boolean_bitnot<uint64_t, uint32_t>(
|
||||
CudaStreams(streams),
|
||||
(boolean_bitnot_buffer<uint64_t, uint32_t> **)mem_ptr, params,
|
||||
lwe_ciphertext_count, is_unchecked, allocate_gpu_memory);
|
||||
}
|
||||
|
||||
void cuda_boolean_bitnot_ciphertext_64(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *lwe_array,
|
||||
int8_t *mem_ptr, void *const *bsks,
|
||||
@@ -73,6 +93,16 @@ void cuda_boolean_bitnot_ciphertext_64(CudaStreamsFFI streams,
|
||||
(uint64_t **)(ksks));
|
||||
}
|
||||
|
||||
void cuda_boolean_bitnot_ciphertext_64_ks32(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *lwe_array,
|
||||
int8_t *mem_ptr, void *const *bsks,
|
||||
void *const *ksks) {
|
||||
host_boolean_bitnot<uint64_t, uint32_t>(
|
||||
CudaStreams(streams), lwe_array,
|
||||
(boolean_bitnot_buffer<uint64_t, uint32_t> *)mem_ptr, bsks,
|
||||
(uint32_t **)(ksks));
|
||||
}
|
||||
|
||||
void cleanup_cuda_boolean_bitnot(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void) {
|
||||
|
||||
@@ -102,6 +132,25 @@ uint64_t scratch_cuda_bitop_64(
|
||||
lwe_ciphertext_count, params, op_type, allocate_gpu_memory);
|
||||
}
|
||||
|
||||
uint64_t scratch_cuda_bitop_64_ks32(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
|
||||
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
|
||||
uint32_t lwe_ciphertext_count, uint32_t message_modulus,
|
||||
uint32_t carry_modulus, PBS_TYPE pbs_type, BITOP_TYPE op_type,
|
||||
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type) {
|
||||
|
||||
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
|
||||
big_lwe_dimension, small_lwe_dimension, ks_level,
|
||||
ks_base_log, pbs_level, pbs_base_log, grouping_factor,
|
||||
message_modulus, carry_modulus, noise_reduction_type);
|
||||
|
||||
return scratch_cuda_bitop<uint64_t, uint32_t>(
|
||||
CudaStreams(streams), (int_bitop_buffer<uint64_t, uint32_t> **)mem_ptr,
|
||||
lwe_ciphertext_count, params, op_type, allocate_gpu_memory);
|
||||
}
|
||||
|
||||
void cuda_bitnot_ciphertext_64(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *radix_ciphertext,
|
||||
uint32_t ct_message_modulus,
|
||||
@@ -126,6 +175,19 @@ void cuda_bitop_ciphertext_64(CudaStreamsFFI streams,
|
||||
(uint64_t **)(ksks));
|
||||
}
|
||||
|
||||
void cuda_bitop_ciphertext_64_ks32(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_1,
|
||||
CudaRadixCiphertextFFI const *lwe_array_2,
|
||||
int8_t *mem_ptr, void *const *bsks,
|
||||
void *const *ksks) {
|
||||
|
||||
host_bitop<uint64_t>(CudaStreams(streams), lwe_array_out, lwe_array_1,
|
||||
lwe_array_2,
|
||||
(int_bitop_buffer<uint64_t, uint32_t> *)mem_ptr, bsks,
|
||||
(uint32_t **)(ksks));
|
||||
}
|
||||
|
||||
void cleanup_cuda_integer_bitop(CudaStreamsFFI streams, int8_t **mem_ptr_void) {
|
||||
|
||||
int_bitop_buffer<uint64_t, uint64_t> *mem_ptr =
|
||||
|
||||
@@ -49,6 +49,28 @@ uint64_t scratch_cuda_cast_to_unsigned_64(
|
||||
requires_full_propagate, allocate_gpu_memory);
|
||||
}
|
||||
|
||||
uint64_t scratch_cuda_cast_to_unsigned_64_ks32(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
|
||||
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
|
||||
uint32_t num_input_blocks, uint32_t target_num_blocks, bool input_is_signed,
|
||||
bool requires_full_propagate, uint32_t message_modulus,
|
||||
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type) {
|
||||
|
||||
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
|
||||
big_lwe_dimension, small_lwe_dimension, ks_level,
|
||||
ks_base_log, pbs_level, pbs_base_log, grouping_factor,
|
||||
message_modulus, carry_modulus, noise_reduction_type);
|
||||
|
||||
return scratch_cuda_cast_to_unsigned<uint64_t, uint32_t>(
|
||||
CudaStreams(streams),
|
||||
(int_cast_to_unsigned_buffer<uint64_t, uint32_t> **)mem_ptr, params,
|
||||
num_input_blocks, target_num_blocks, input_is_signed,
|
||||
requires_full_propagate, allocate_gpu_memory);
|
||||
}
|
||||
|
||||
void cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *output,
|
||||
CudaRadixCiphertextFFI *input, int8_t *mem_ptr,
|
||||
@@ -61,6 +83,19 @@ void cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
|
||||
target_num_blocks, input_is_signed, bsks, (uint64_t **)ksks);
|
||||
}
|
||||
|
||||
void cuda_cast_to_unsigned_64_ks32(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *output,
|
||||
CudaRadixCiphertextFFI *input,
|
||||
int8_t *mem_ptr, uint32_t target_num_blocks,
|
||||
bool input_is_signed, void *const *bsks,
|
||||
void *const *ksks) {
|
||||
|
||||
host_cast_to_unsigned<uint64_t, uint32_t>(
|
||||
CudaStreams(streams), output, input,
|
||||
(int_cast_to_unsigned_buffer<uint64_t, uint32_t> *)mem_ptr,
|
||||
target_num_blocks, input_is_signed, bsks, (uint32_t **)ksks);
|
||||
}
|
||||
|
||||
void cleanup_cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void) {
|
||||
int_cast_to_unsigned_buffer<uint64_t, uint64_t> *mem_ptr =
|
||||
|
||||
@@ -130,7 +130,7 @@ host_cast_to_unsigned(CudaStreams streams, CudaRadixCiphertextFFI *output,
|
||||
CudaRadixCiphertextFFI *input,
|
||||
int_cast_to_unsigned_buffer<Torus, KSTorus> *mem_ptr,
|
||||
uint32_t target_num_blocks, bool input_is_signed,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
uint32_t current_num_blocks = input->num_radix_blocks;
|
||||
|
||||
@@ -143,9 +143,9 @@ host_cast_to_unsigned(CudaStreams streams, CudaRadixCiphertextFFI *output,
|
||||
uint32_t num_blocks_to_add = target_num_blocks - current_num_blocks;
|
||||
|
||||
if (input_is_signed) {
|
||||
host_extend_radix_with_sign_msb<Torus>(
|
||||
host_extend_radix_with_sign_msb<Torus, KSTorus>(
|
||||
streams, output, input, mem_ptr->extend_buffer, num_blocks_to_add,
|
||||
bsks, (Torus **)ksks);
|
||||
bsks, ksks);
|
||||
} else {
|
||||
host_extend_radix_with_trivial_zero_blocks_msb<Torus>(output, input,
|
||||
streams);
|
||||
|
||||
@@ -26,6 +26,30 @@ uint64_t scratch_cuda_cmux_64(CudaStreamsFFI streams, int8_t **mem_ptr,
|
||||
return ret;
|
||||
}
|
||||
|
||||
uint64_t scratch_cuda_cmux_64_ks32(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
|
||||
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
|
||||
uint32_t lwe_ciphertext_count, uint32_t message_modulus,
|
||||
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type) {
|
||||
PUSH_RANGE("scratch cmux")
|
||||
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
|
||||
big_lwe_dimension, small_lwe_dimension, ks_level,
|
||||
ks_base_log, pbs_level, pbs_base_log, grouping_factor,
|
||||
message_modulus, carry_modulus, noise_reduction_type);
|
||||
|
||||
std::function<uint64_t(uint64_t)> predicate_lut_f =
|
||||
[](uint64_t x) -> uint64_t { return x == 1; };
|
||||
|
||||
uint64_t ret = scratch_cuda_cmux<uint64_t, uint32_t>(
|
||||
CudaStreams(streams), (int_cmux_buffer<uint64_t, uint32_t> **)mem_ptr,
|
||||
predicate_lut_f, lwe_ciphertext_count, params, allocate_gpu_memory);
|
||||
POP_RANGE()
|
||||
return ret;
|
||||
}
|
||||
|
||||
void cuda_cmux_ciphertext_64(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_condition,
|
||||
@@ -34,10 +58,25 @@ void cuda_cmux_ciphertext_64(CudaStreamsFFI streams,
|
||||
int8_t *mem_ptr, void *const *bsks,
|
||||
void *const *ksks) {
|
||||
PUSH_RANGE("cmux")
|
||||
host_cmux<uint64_t>(CudaStreams(streams), lwe_array_out, lwe_condition,
|
||||
lwe_array_true, lwe_array_false,
|
||||
(int_cmux_buffer<uint64_t, uint64_t> *)mem_ptr, bsks,
|
||||
(uint64_t **)(ksks));
|
||||
host_cmux<uint64_t, uint64_t>(CudaStreams(streams), lwe_array_out,
|
||||
lwe_condition, lwe_array_true, lwe_array_false,
|
||||
(int_cmux_buffer<uint64_t, uint64_t> *)mem_ptr,
|
||||
bsks, (uint64_t **)(ksks));
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
void cuda_cmux_ciphertext_64_ks32(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_condition,
|
||||
CudaRadixCiphertextFFI const *lwe_array_true,
|
||||
CudaRadixCiphertextFFI const *lwe_array_false,
|
||||
int8_t *mem_ptr, void *const *bsks,
|
||||
void *const *ksks) {
|
||||
PUSH_RANGE("cmux")
|
||||
host_cmux<uint64_t, uint32_t>(CudaStreams(streams), lwe_array_out,
|
||||
lwe_condition, lwe_array_true, lwe_array_false,
|
||||
(int_cmux_buffer<uint64_t, uint32_t> *)mem_ptr,
|
||||
bsks, (uint32_t **)(ksks));
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
|
||||
@@ -39,6 +39,45 @@ uint64_t scratch_cuda_comparison_64(
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
uint64_t scratch_cuda_comparison_64_ks32(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
|
||||
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
|
||||
uint32_t num_radix_blocks, uint32_t message_modulus, uint32_t carry_modulus,
|
||||
PBS_TYPE pbs_type, COMPARISON_TYPE op_type, bool is_signed,
|
||||
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type) {
|
||||
PUSH_RANGE("scratch comparison")
|
||||
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
|
||||
big_lwe_dimension, small_lwe_dimension, ks_level,
|
||||
ks_base_log, pbs_level, pbs_base_log, grouping_factor,
|
||||
message_modulus, carry_modulus, noise_reduction_type);
|
||||
|
||||
uint64_t size_tracker = 0;
|
||||
switch (op_type) {
|
||||
case EQ:
|
||||
case NE:
|
||||
size_tracker += scratch_cuda_comparison_check<uint64_t, uint32_t>(
|
||||
CudaStreams(streams),
|
||||
(int_comparison_buffer<uint64_t, uint32_t> **)mem_ptr, num_radix_blocks,
|
||||
params, op_type, false, allocate_gpu_memory);
|
||||
break;
|
||||
case GT:
|
||||
case GE:
|
||||
case LT:
|
||||
case LE:
|
||||
case MAX:
|
||||
case MIN:
|
||||
size_tracker += scratch_cuda_comparison_check<uint64_t, uint32_t>(
|
||||
CudaStreams(streams),
|
||||
(int_comparison_buffer<uint64_t, uint32_t> **)mem_ptr, num_radix_blocks,
|
||||
params, op_type, is_signed, allocate_gpu_memory);
|
||||
break;
|
||||
}
|
||||
POP_RANGE()
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
void cuda_comparison_ciphertext_64(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_1,
|
||||
@@ -87,6 +126,53 @@ void cuda_comparison_ciphertext_64(CudaStreamsFFI streams,
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
void cuda_comparison_ciphertext_64_ks32(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_1,
|
||||
CudaRadixCiphertextFFI const *lwe_array_2, int8_t *mem_ptr,
|
||||
void *const *bsks, void *const *ksks) {
|
||||
PUSH_RANGE("comparison")
|
||||
if (lwe_array_1->num_radix_blocks != lwe_array_2->num_radix_blocks)
|
||||
PANIC("Cuda error: input num radix blocks must be the same")
|
||||
// The output ciphertext might be a boolean block or a radix ciphertext
|
||||
// depending on the case (eq/gt vs max/min) so the amount of blocks to
|
||||
// consider for calculation is the one of the input
|
||||
auto num_radix_blocks = lwe_array_1->num_radix_blocks;
|
||||
int_comparison_buffer<uint64_t, uint32_t> *buffer =
|
||||
(int_comparison_buffer<uint64_t, uint32_t> *)mem_ptr;
|
||||
switch (buffer->op) {
|
||||
case EQ:
|
||||
case NE:
|
||||
host_equality_check<uint64_t>(CudaStreams(streams), lwe_array_out,
|
||||
lwe_array_1, lwe_array_2, buffer, bsks,
|
||||
(uint32_t **)(ksks), num_radix_blocks);
|
||||
break;
|
||||
case GT:
|
||||
case GE:
|
||||
case LT:
|
||||
case LE:
|
||||
if (num_radix_blocks % 2 != 0)
|
||||
PANIC("Cuda error (comparisons): the number of radix blocks has to be "
|
||||
"even.")
|
||||
host_difference_check<uint64_t>(CudaStreams(streams), lwe_array_out,
|
||||
lwe_array_1, lwe_array_2, buffer,
|
||||
buffer->diff_buffer->operator_f, bsks,
|
||||
(uint32_t **)(ksks), num_radix_blocks);
|
||||
break;
|
||||
case MAX:
|
||||
case MIN:
|
||||
if (num_radix_blocks % 2 != 0)
|
||||
PANIC("Cuda error (max/min): the number of radix blocks has to be even.")
|
||||
host_maxmin<uint64_t>(CudaStreams(streams), lwe_array_out, lwe_array_1,
|
||||
lwe_array_2, buffer, bsks, (uint32_t **)(ksks),
|
||||
num_radix_blocks);
|
||||
break;
|
||||
default:
|
||||
PANIC("Cuda error: integer operation not supported")
|
||||
}
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
void cleanup_cuda_integer_comparison(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void) {
|
||||
PUSH_RANGE("cleanup comparison")
|
||||
|
||||
@@ -21,6 +21,27 @@ uint64_t scratch_cuda_integer_div_rem_radix_ciphertext_64(
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
uint64_t scratch_cuda_integer_div_rem_radix_ciphertext_64_ks32(
|
||||
CudaStreamsFFI streams, bool is_signed, int8_t **mem_ptr,
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t big_lwe_dimension, uint32_t small_lwe_dimension, uint32_t ks_level,
|
||||
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
|
||||
uint32_t grouping_factor, uint32_t num_blocks, uint32_t message_modulus,
|
||||
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type) {
|
||||
PUSH_RANGE("scratch div")
|
||||
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
|
||||
big_lwe_dimension, small_lwe_dimension, ks_level,
|
||||
ks_base_log, pbs_level, pbs_base_log, grouping_factor,
|
||||
message_modulus, carry_modulus, noise_reduction_type);
|
||||
|
||||
return scratch_cuda_integer_div_rem<uint64_t, uint32_t>(
|
||||
CudaStreams(streams), is_signed,
|
||||
(int_div_rem_memory<uint64_t, uint32_t> **)mem_ptr, num_blocks, params,
|
||||
allocate_gpu_memory);
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
void cuda_integer_div_rem_radix_ciphertext_64(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *quotient,
|
||||
CudaRadixCiphertextFFI *remainder, CudaRadixCiphertextFFI const *numerator,
|
||||
@@ -35,6 +56,20 @@ void cuda_integer_div_rem_radix_ciphertext_64(
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
void cuda_integer_div_rem_radix_ciphertext_64_ks32(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *quotient,
|
||||
CudaRadixCiphertextFFI *remainder, CudaRadixCiphertextFFI const *numerator,
|
||||
CudaRadixCiphertextFFI const *divisor, bool is_signed, int8_t *mem_ptr,
|
||||
void *const *bsks, void *const *ksks) {
|
||||
PUSH_RANGE("div")
|
||||
auto mem = (int_div_rem_memory<uint64_t, uint32_t> *)mem_ptr;
|
||||
|
||||
host_integer_div_rem<uint64_t, uint32_t>(
|
||||
CudaStreams(streams), quotient, remainder, numerator, divisor, is_signed,
|
||||
bsks, (uint32_t **)(ksks), mem);
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
void cleanup_cuda_integer_div_rem(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void) {
|
||||
PUSH_RANGE("cleanup div")
|
||||
|
||||
@@ -31,8 +31,8 @@ __host__ void host_unsigned_integer_div_rem_block_by_block_2_2(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *quotient,
|
||||
CudaRadixCiphertextFFI *remainder, CudaRadixCiphertextFFI const *numerator,
|
||||
CudaRadixCiphertextFFI const *divisor, void *const *bsks,
|
||||
uint64_t *const *ksks,
|
||||
unsigned_int_div_rem_2_2_memory<uint64_t, uint64_t> *mem_ptr) {
|
||||
KSTorus *const *ksks,
|
||||
unsigned_int_div_rem_2_2_memory<Torus, KSTorus> *mem_ptr) {
|
||||
|
||||
if (streams.count() < 4) {
|
||||
PANIC("GPU count should be greater than 4 when using div_rem_2_2");
|
||||
@@ -476,8 +476,8 @@ __host__ void host_unsigned_integer_div_rem(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *quotient,
|
||||
CudaRadixCiphertextFFI *remainder, CudaRadixCiphertextFFI const *numerator,
|
||||
CudaRadixCiphertextFFI const *divisor, void *const *bsks,
|
||||
uint64_t *const *ksks,
|
||||
unsigned_int_div_rem_memory<uint64_t, uint64_t> *mem_ptr) {
|
||||
KSTorus *const *ksks,
|
||||
unsigned_int_div_rem_memory<Torus, KSTorus> *mem_ptr) {
|
||||
|
||||
if (remainder->num_radix_blocks != numerator->num_radix_blocks ||
|
||||
remainder->num_radix_blocks != divisor->num_radix_blocks ||
|
||||
@@ -740,7 +740,7 @@ __host__ void host_unsigned_integer_div_rem(
|
||||
mem_ptr->overflow_sub_mem->update_lut_indexes(
|
||||
streams, first_indexes, second_indexes, scalar_indexes,
|
||||
merged_interesting_remainder->num_radix_blocks);
|
||||
host_integer_overflowing_sub<uint64_t>(
|
||||
host_integer_overflowing_sub<Torus, KSTorus>(
|
||||
streams, new_remainder, merged_interesting_remainder,
|
||||
interesting_divisor, subtraction_overflowed,
|
||||
(const CudaRadixCiphertextFFI *)nullptr, mem_ptr->overflow_sub_mem,
|
||||
@@ -903,13 +903,11 @@ __host__ void host_unsigned_integer_div_rem(
|
||||
}
|
||||
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void
|
||||
host_integer_div_rem(CudaStreams streams, CudaRadixCiphertextFFI *quotient,
|
||||
CudaRadixCiphertextFFI *remainder,
|
||||
CudaRadixCiphertextFFI const *numerator,
|
||||
CudaRadixCiphertextFFI const *divisor, bool is_signed,
|
||||
void *const *bsks, uint64_t *const *ksks,
|
||||
int_div_rem_memory<uint64_t, uint64_t> *int_mem_ptr) {
|
||||
__host__ void host_integer_div_rem(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *quotient,
|
||||
CudaRadixCiphertextFFI *remainder, CudaRadixCiphertextFFI const *numerator,
|
||||
CudaRadixCiphertextFFI const *divisor, bool is_signed, void *const *bsks,
|
||||
KSTorus *const *ksks, int_div_rem_memory<Torus, KSTorus> *int_mem_ptr) {
|
||||
if (remainder->num_radix_blocks != numerator->num_radix_blocks ||
|
||||
remainder->num_radix_blocks != divisor->num_radix_blocks ||
|
||||
remainder->num_radix_blocks != quotient->num_radix_blocks)
|
||||
|
||||
@@ -116,6 +116,25 @@ uint64_t scratch_cuda_integer_overflowing_sub_64_inplace(
|
||||
params, compute_overflow, allocate_gpu_memory);
|
||||
}
|
||||
|
||||
uint64_t scratch_cuda_integer_overflowing_sub_64_ks32_inplace(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
|
||||
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
|
||||
uint32_t num_blocks, uint32_t message_modulus, uint32_t carry_modulus,
|
||||
PBS_TYPE pbs_type, uint32_t compute_overflow, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type) {
|
||||
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
|
||||
big_lwe_dimension, small_lwe_dimension, ks_level,
|
||||
ks_base_log, pbs_level, pbs_base_log, grouping_factor,
|
||||
message_modulus, carry_modulus, noise_reduction_type);
|
||||
|
||||
return scratch_cuda_integer_overflowing_sub<uint64_t, uint32_t>(
|
||||
CudaStreams(streams),
|
||||
(int_borrow_prop_memory<uint64_t, uint32_t> **)mem_ptr, num_blocks,
|
||||
params, compute_overflow, allocate_gpu_memory);
|
||||
}
|
||||
|
||||
void cuda_propagate_single_carry_64_inplace(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array,
|
||||
CudaRadixCiphertextFFI *carry_out, const CudaRadixCiphertextFFI *carry_in,
|
||||
@@ -128,6 +147,18 @@ void cuda_propagate_single_carry_64_inplace(
|
||||
(uint64_t **)(ksks), requested_flag, uses_carry);
|
||||
}
|
||||
|
||||
void cuda_propagate_single_carry_64_ks32_inplace(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array,
|
||||
CudaRadixCiphertextFFI *carry_out, const CudaRadixCiphertextFFI *carry_in,
|
||||
int8_t *mem_ptr, void *const *bsks, void *const *ksks,
|
||||
uint32_t requested_flag, uint32_t uses_carry) {
|
||||
|
||||
host_propagate_single_carry<uint64_t, uint32_t>(
|
||||
CudaStreams(streams), lwe_array, carry_out, carry_in,
|
||||
(int_sc_prop_memory<uint64_t, uint32_t> *)mem_ptr, bsks,
|
||||
(uint32_t **)(ksks), requested_flag, uses_carry);
|
||||
}
|
||||
|
||||
void cuda_add_and_propagate_single_carry_64_inplace(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lhs_array,
|
||||
const CudaRadixCiphertextFFI *rhs_array, CudaRadixCiphertextFFI *carry_out,
|
||||
@@ -146,10 +177,10 @@ void cuda_add_and_propagate_single_carry_64_ks32_inplace(
|
||||
const CudaRadixCiphertextFFI *carry_in, int8_t *mem_ptr, void *const *bsks,
|
||||
void *const *ksks, uint32_t requested_flag, uint32_t uses_carry) {
|
||||
|
||||
host_add_and_propagate_single_carry<uint64_t, uint32_t>(
|
||||
CudaStreams(streams), lhs_array, rhs_array, carry_out, carry_in,
|
||||
(int_sc_prop_memory<uint64_t, uint32_t> *)mem_ptr, bsks,
|
||||
(uint32_t **)(ksks), requested_flag, uses_carry);
|
||||
host_add_and_propagate_single_carry<uint64_t, uint32_t>(
|
||||
CudaStreams(streams), lhs_array, rhs_array, carry_out, carry_in,
|
||||
(int_sc_prop_memory<uint64_t, uint32_t> *)mem_ptr, bsks,
|
||||
(uint32_t **)(ksks), requested_flag, uses_carry);
|
||||
}
|
||||
|
||||
void cuda_integer_overflowing_sub_64_inplace(
|
||||
@@ -167,6 +198,21 @@ void cuda_integer_overflowing_sub_64_inplace(
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
void cuda_integer_overflowing_sub_64_ks32_inplace(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lhs_array,
|
||||
const CudaRadixCiphertextFFI *rhs_array,
|
||||
CudaRadixCiphertextFFI *overflow_block,
|
||||
const CudaRadixCiphertextFFI *input_borrow, int8_t *mem_ptr,
|
||||
void *const *bsks, void *const *ksks, uint32_t compute_overflow,
|
||||
uint32_t uses_input_borrow) {
|
||||
PUSH_RANGE("overflow sub")
|
||||
host_integer_overflowing_sub<uint64_t, uint32_t>(
|
||||
CudaStreams(streams), lhs_array, lhs_array, rhs_array, overflow_block,
|
||||
input_borrow, (int_borrow_prop_memory<uint64_t, uint32_t> *)mem_ptr, bsks,
|
||||
(uint32_t **)ksks, compute_overflow, uses_input_borrow);
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
void cleanup_cuda_propagate_single_carry(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void) {
|
||||
PUSH_RANGE("cleanup propagate sc")
|
||||
|
||||
@@ -115,7 +115,7 @@ void host_integer_grouped_oprf_custom_range(
|
||||
const Torus *decomposed_scalar, const Torus *has_at_least_one_set,
|
||||
uint32_t num_scalars, uint32_t shift,
|
||||
int_grouped_oprf_custom_range_memory<Torus, KSTorus> *mem_ptr,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
CudaRadixCiphertextFFI *computation_buffer = mem_ptr->tmp_oprf_output;
|
||||
set_zero_radix_ciphertext_slice_async<Torus>(
|
||||
|
||||
@@ -20,6 +20,26 @@ uint64_t scratch_cuda_sub_and_propagate_single_carry_64_inplace(
|
||||
requested_flag, allocate_gpu_memory);
|
||||
}
|
||||
|
||||
uint64_t scratch_cuda_sub_and_propagate_single_carry_64_ks32_inplace(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
|
||||
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
|
||||
uint32_t num_blocks, uint32_t message_modulus, uint32_t carry_modulus,
|
||||
PBS_TYPE pbs_type, uint32_t requested_flag, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type) {
|
||||
|
||||
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
|
||||
big_lwe_dimension, small_lwe_dimension, ks_level,
|
||||
ks_base_log, pbs_level, pbs_base_log, grouping_factor,
|
||||
message_modulus, carry_modulus, noise_reduction_type);
|
||||
|
||||
return scratch_cuda_sub_and_propagate_single_carry<uint64_t, uint32_t>(
|
||||
CudaStreams(streams),
|
||||
(int_sub_and_propagate<uint64_t, uint32_t> **)mem_ptr, num_blocks, params,
|
||||
requested_flag, allocate_gpu_memory);
|
||||
}
|
||||
|
||||
void cuda_sub_and_propagate_single_carry_64_inplace(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lhs_array,
|
||||
const CudaRadixCiphertextFFI *rhs_array, CudaRadixCiphertextFFI *carry_out,
|
||||
@@ -33,6 +53,19 @@ void cuda_sub_and_propagate_single_carry_64_inplace(
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
void cuda_sub_and_propagate_single_carry_64_ks32_inplace(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lhs_array,
|
||||
const CudaRadixCiphertextFFI *rhs_array, CudaRadixCiphertextFFI *carry_out,
|
||||
const CudaRadixCiphertextFFI *carry_in, int8_t *mem_ptr, void *const *bsks,
|
||||
void *const *ksks, uint32_t requested_flag, uint32_t uses_carry) {
|
||||
PUSH_RANGE("sub")
|
||||
host_sub_and_propagate_single_carry<uint64_t, uint32_t>(
|
||||
CudaStreams(streams), lhs_array, rhs_array, carry_out, carry_in,
|
||||
(int_sub_and_propagate<uint64_t, uint32_t> *)mem_ptr, bsks,
|
||||
(uint32_t **)(ksks), requested_flag, uses_carry);
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
void cleanup_cuda_sub_and_propagate_single_carry(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void) {
|
||||
PUSH_RANGE("cleanup sub")
|
||||
|
||||
@@ -95,7 +95,8 @@ __host__ void host_integer_overflowing_sub(
|
||||
CudaRadixCiphertextFFI *overflow_block,
|
||||
const CudaRadixCiphertextFFI *input_borrow,
|
||||
int_borrow_prop_memory<Torus, KSTorus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t compute_overflow, uint32_t uses_input_borrow) {
|
||||
KSTorus *const *ksks, uint32_t compute_overflow,
|
||||
uint32_t uses_input_borrow) {
|
||||
PUSH_RANGE("overflowing sub")
|
||||
if (output->num_radix_blocks != input_left->num_radix_blocks ||
|
||||
output->num_radix_blocks != input_right->num_radix_blocks)
|
||||
@@ -124,8 +125,8 @@ __host__ void host_integer_overflowing_sub(
|
||||
|
||||
host_single_borrow_propagate<Torus>(
|
||||
streams, output, overflow_block, input_borrow,
|
||||
(int_borrow_prop_memory<Torus, KSTorus> *)mem_ptr, bsks, (Torus **)(ksks),
|
||||
num_groups, compute_overflow, uses_input_borrow);
|
||||
(int_borrow_prop_memory<Torus, KSTorus> *)mem_ptr, bsks,
|
||||
(KSTorus **)(ksks), num_groups, compute_overflow, uses_input_borrow);
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ __host__ void host_unchecked_all_eq_slices(
|
||||
CudaRadixCiphertextFFI const *lhs, CudaRadixCiphertextFFI const *rhs,
|
||||
uint32_t num_inputs, uint32_t num_blocks,
|
||||
int_unchecked_all_eq_slices_buffer<Torus, KSTorus> *mem_ptr,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
// sync_from(streams)
|
||||
//
|
||||
@@ -104,7 +104,7 @@ __host__ void host_unchecked_contains_sub_slice(
|
||||
CudaRadixCiphertextFFI const *lhs, CudaRadixCiphertextFFI const *rhs,
|
||||
uint32_t num_rhs, uint32_t num_blocks,
|
||||
int_unchecked_contains_sub_slice_buffer<Torus, KSTorus> *mem_ptr,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
uint32_t num_windows = mem_ptr->num_windows;
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ __host__ void host_compute_equality_selectors(
|
||||
CudaRadixCiphertextFFI const *lwe_array_in, uint32_t num_blocks,
|
||||
const uint64_t *h_decomposed_cleartexts,
|
||||
int_equality_selectors_buffer<Torus, KSTorus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks) {
|
||||
KSTorus *const *ksks) {
|
||||
|
||||
uint32_t num_possible_values = mem_ptr->num_possible_values;
|
||||
uint32_t message_modulus = mem_ptr->params.message_modulus;
|
||||
@@ -91,7 +91,7 @@ __host__ void host_create_possible_results(
|
||||
CudaRadixCiphertextFFI const *lwe_array_in_list,
|
||||
uint32_t num_possible_values, const uint64_t *h_decomposed_cleartexts,
|
||||
uint32_t num_blocks, int_possible_results_buffer<Torus, KSTorus> *mem_ptr,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
uint32_t max_packed_value = mem_ptr->max_packed_value;
|
||||
uint32_t max_luts_per_call = mem_ptr->max_luts_per_call;
|
||||
@@ -174,7 +174,7 @@ __host__ void host_aggregate_one_hot_vector(
|
||||
CudaRadixCiphertextFFI const *lwe_array_in_list,
|
||||
uint32_t num_input_ciphertexts, uint32_t num_blocks,
|
||||
int_aggregate_one_hot_buffer<Torus, KSTorus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks) {
|
||||
KSTorus *const *ksks) {
|
||||
|
||||
int_radix_params params = mem_ptr->params;
|
||||
uint32_t chunk_size = mem_ptr->chunk_size;
|
||||
@@ -345,7 +345,7 @@ __host__ void host_unchecked_match_value(
|
||||
CudaRadixCiphertextFFI const *lwe_array_in_ct,
|
||||
const uint64_t *h_match_inputs, const uint64_t *h_match_outputs,
|
||||
int_unchecked_match_buffer<Torus, KSTorus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks) {
|
||||
KSTorus *const *ksks) {
|
||||
host_compute_equality_selectors<Torus>(
|
||||
streams, mem_ptr->selectors_list, lwe_array_in_ct,
|
||||
mem_ptr->num_input_blocks, h_match_inputs, mem_ptr->eq_selectors_buffer,
|
||||
@@ -422,7 +422,7 @@ __host__ void host_unchecked_match_value_or(
|
||||
const uint64_t *h_match_inputs, const uint64_t *h_match_outputs,
|
||||
const uint64_t *h_or_value,
|
||||
int_unchecked_match_value_or_buffer<Torus, KSTorus> *mem_ptr,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
host_unchecked_match_value<Torus>(streams, mem_ptr->tmp_match_result,
|
||||
mem_ptr->tmp_match_bool, lwe_array_in_ct,
|
||||
@@ -465,7 +465,7 @@ host_unchecked_contains(CudaStreams streams, CudaRadixCiphertextFFI *output,
|
||||
CudaRadixCiphertextFFI const *value,
|
||||
uint32_t num_inputs, uint32_t num_blocks,
|
||||
int_unchecked_contains_buffer<Torus, KSTorus> *mem_ptr,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
mem_ptr->internal_cuda_streams.internal_streams_wait_for_main_stream_0(
|
||||
streams);
|
||||
@@ -516,7 +516,7 @@ __host__ void host_unchecked_contains_clear(
|
||||
CudaRadixCiphertextFFI const *inputs, const uint64_t *h_clear_val,
|
||||
uint32_t num_inputs, uint32_t num_blocks,
|
||||
int_unchecked_contains_clear_buffer<Torus, KSTorus> *mem_ptr,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
cuda_memcpy_async_to_gpu(mem_ptr->d_clear_val, h_clear_val,
|
||||
num_blocks * sizeof(Torus), streams.stream(0),
|
||||
@@ -577,7 +577,7 @@ __host__ void host_unchecked_is_in_clears(
|
||||
CudaRadixCiphertextFFI const *input, const uint64_t *h_cleartexts,
|
||||
uint32_t num_clears, uint32_t num_blocks,
|
||||
int_unchecked_is_in_clears_buffer<Torus, KSTorus> *mem_ptr,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
host_compute_equality_selectors<Torus>(streams, mem_ptr->unpacked_selectors,
|
||||
input, num_blocks, h_cleartexts,
|
||||
@@ -594,7 +594,7 @@ __host__ void host_compute_final_index_from_selectors(
|
||||
CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *selectors,
|
||||
uint32_t num_inputs, uint32_t num_blocks_index,
|
||||
int_final_index_from_selectors_buffer<Torus, KSTorus> *mem_ptr,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
for (uint32_t i = 0; i < num_inputs; i++) {
|
||||
CudaRadixCiphertextFFI const *src_selector = &selectors[i];
|
||||
@@ -657,7 +657,7 @@ __host__ void host_unchecked_index_in_clears(
|
||||
const uint64_t *h_cleartexts, uint32_t num_clears, uint32_t num_blocks,
|
||||
uint32_t num_blocks_index,
|
||||
int_unchecked_index_in_clears_buffer<Torus, KSTorus> *mem_ptr,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
host_compute_equality_selectors<Torus>(
|
||||
streams, mem_ptr->final_index_buf->unpacked_selectors, input, num_blocks,
|
||||
@@ -704,7 +704,7 @@ __host__ void host_unchecked_first_index_in_clears(
|
||||
const uint64_t *h_unique_values, const uint64_t *h_unique_indices,
|
||||
uint32_t num_unique, uint32_t num_blocks, uint32_t num_blocks_index,
|
||||
int_unchecked_first_index_in_clears_buffer<Torus, KSTorus> *mem_ptr,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
host_compute_equality_selectors<Torus>(streams, mem_ptr->unpacked_selectors,
|
||||
input, num_blocks, h_unique_values,
|
||||
@@ -748,7 +748,7 @@ __host__ void host_unchecked_first_index_of_clear(
|
||||
const uint64_t *h_clear_val, uint32_t num_inputs, uint32_t num_blocks,
|
||||
uint32_t num_blocks_index,
|
||||
int_unchecked_first_index_of_clear_buffer<Torus, KSTorus> *mem_ptr,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
cuda_memcpy_async_to_gpu(mem_ptr->d_clear_val, h_clear_val,
|
||||
num_blocks * sizeof(Torus), streams.stream(0),
|
||||
@@ -841,7 +841,7 @@ __host__ void host_unchecked_first_index_of(
|
||||
CudaRadixCiphertextFFI const *value, uint32_t num_inputs,
|
||||
uint32_t num_blocks, uint32_t num_blocks_index,
|
||||
int_unchecked_first_index_of_buffer<Torus, KSTorus> *mem_ptr,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
mem_ptr->internal_cuda_streams.internal_streams_wait_for_main_stream_0(
|
||||
streams);
|
||||
@@ -924,7 +924,7 @@ __host__ void host_unchecked_index_of(
|
||||
CudaRadixCiphertextFFI const *value, uint32_t num_inputs,
|
||||
uint32_t num_blocks, uint32_t num_blocks_index,
|
||||
int_unchecked_index_of_buffer<Torus, KSTorus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks) {
|
||||
KSTorus *const *ksks) {
|
||||
|
||||
mem_ptr->internal_cuda_streams.internal_streams_wait_for_main_stream_0(
|
||||
streams);
|
||||
@@ -992,7 +992,7 @@ __host__ void host_unchecked_index_of_clear(
|
||||
uint32_t num_inputs, uint32_t num_blocks, uint32_t num_scalar_blocks,
|
||||
uint32_t num_blocks_index,
|
||||
int_unchecked_index_of_clear_buffer<Torus, KSTorus> *mem_ptr,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
CudaRadixCiphertextFFI *packed_selectors =
|
||||
mem_ptr->final_index_buf->packed_selectors;
|
||||
|
||||
@@ -651,6 +651,29 @@ unsafe extern "C" {
|
||||
noise_reduction_type: PBS_MS_REDUCTION_T,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn scratch_cuda_comparison_64_ks32(
|
||||
streams: CudaStreamsFFI,
|
||||
mem_ptr: *mut *mut i8,
|
||||
glwe_dimension: u32,
|
||||
polynomial_size: u32,
|
||||
big_lwe_dimension: u32,
|
||||
small_lwe_dimension: u32,
|
||||
ks_level: u32,
|
||||
ks_base_log: u32,
|
||||
pbs_level: u32,
|
||||
pbs_base_log: u32,
|
||||
grouping_factor: u32,
|
||||
lwe_ciphertext_count: u32,
|
||||
message_modulus: u32,
|
||||
carry_modulus: u32,
|
||||
pbs_type: PBS_TYPE,
|
||||
op_type: COMPARISON_TYPE,
|
||||
is_signed: bool,
|
||||
allocate_gpu_memory: bool,
|
||||
noise_reduction_type: PBS_MS_REDUCTION_T,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_comparison_ciphertext_64(
|
||||
streams: CudaStreamsFFI,
|
||||
@@ -662,6 +685,17 @@ unsafe extern "C" {
|
||||
ksks: *const *mut ffi::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_comparison_ciphertext_64_ks32(
|
||||
streams: CudaStreamsFFI,
|
||||
lwe_array_out: *mut CudaRadixCiphertextFFI,
|
||||
lwe_array_1: *const CudaRadixCiphertextFFI,
|
||||
lwe_array_2: *const CudaRadixCiphertextFFI,
|
||||
mem_ptr: *mut i8,
|
||||
bsks: *const *mut ffi::c_void,
|
||||
ksks: *const *mut ffi::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_scalar_comparison_ciphertext_64(
|
||||
streams: CudaStreamsFFI,
|
||||
@@ -746,6 +780,37 @@ unsafe extern "C" {
|
||||
ksks: *const *mut ffi::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn scratch_cuda_boolean_bitnot_64_ks32(
|
||||
streams: CudaStreamsFFI,
|
||||
mem_ptr: *mut *mut i8,
|
||||
glwe_dimension: u32,
|
||||
polynomial_size: u32,
|
||||
big_lwe_dimension: u32,
|
||||
small_lwe_dimension: u32,
|
||||
ks_level: u32,
|
||||
ks_base_log: u32,
|
||||
pbs_level: u32,
|
||||
pbs_base_log: u32,
|
||||
grouping_factor: u32,
|
||||
message_modulus: u32,
|
||||
carry_modulus: u32,
|
||||
pbs_type: PBS_TYPE,
|
||||
lwe_ciphertext_count: u32,
|
||||
is_unchecked: bool,
|
||||
allocate_gpu_memory: bool,
|
||||
noise_reduction_type: PBS_MS_REDUCTION_T,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_boolean_bitnot_ciphertext_64_ks32(
|
||||
streams: CudaStreamsFFI,
|
||||
lwe_array: *mut CudaRadixCiphertextFFI,
|
||||
mem_ptr: *mut i8,
|
||||
bsks: *const *mut ffi::c_void,
|
||||
ksks: *const *mut ffi::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cleanup_cuda_boolean_bitnot(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
|
||||
}
|
||||
@@ -780,6 +845,28 @@ unsafe extern "C" {
|
||||
noise_reduction_type: PBS_MS_REDUCTION_T,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn scratch_cuda_bitop_64_ks32(
|
||||
streams: CudaStreamsFFI,
|
||||
mem_ptr: *mut *mut i8,
|
||||
glwe_dimension: u32,
|
||||
polynomial_size: u32,
|
||||
big_lwe_dimension: u32,
|
||||
small_lwe_dimension: u32,
|
||||
ks_level: u32,
|
||||
ks_base_log: u32,
|
||||
pbs_level: u32,
|
||||
pbs_base_log: u32,
|
||||
grouping_factor: u32,
|
||||
lwe_ciphertext_count: u32,
|
||||
message_modulus: u32,
|
||||
carry_modulus: u32,
|
||||
pbs_type: PBS_TYPE,
|
||||
op_type: BITOP_TYPE,
|
||||
allocate_gpu_memory: bool,
|
||||
noise_reduction_type: PBS_MS_REDUCTION_T,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_scalar_bitop_ciphertext_64(
|
||||
streams: CudaStreamsFFI,
|
||||
@@ -804,6 +891,17 @@ unsafe extern "C" {
|
||||
ksks: *const *mut ffi::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_bitop_ciphertext_64_ks32(
|
||||
streams: CudaStreamsFFI,
|
||||
lwe_array_out: *mut CudaRadixCiphertextFFI,
|
||||
lwe_array_1: *const CudaRadixCiphertextFFI,
|
||||
lwe_array_2: *const CudaRadixCiphertextFFI,
|
||||
mem_ptr: *mut i8,
|
||||
bsks: *const *mut ffi::c_void,
|
||||
ksks: *const *mut ffi::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cleanup_cuda_integer_bitop(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
|
||||
}
|
||||
@@ -828,6 +926,27 @@ unsafe extern "C" {
|
||||
noise_reduction_type: PBS_MS_REDUCTION_T,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn scratch_cuda_cmux_64_ks32(
|
||||
streams: CudaStreamsFFI,
|
||||
mem_ptr: *mut *mut i8,
|
||||
glwe_dimension: u32,
|
||||
polynomial_size: u32,
|
||||
big_lwe_dimension: u32,
|
||||
small_lwe_dimension: u32,
|
||||
ks_level: u32,
|
||||
ks_base_log: u32,
|
||||
pbs_level: u32,
|
||||
pbs_base_log: u32,
|
||||
grouping_factor: u32,
|
||||
lwe_ciphertext_count: u32,
|
||||
message_modulus: u32,
|
||||
carry_modulus: u32,
|
||||
pbs_type: PBS_TYPE,
|
||||
allocate_gpu_memory: bool,
|
||||
noise_reduction_type: PBS_MS_REDUCTION_T,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_cmux_ciphertext_64(
|
||||
streams: CudaStreamsFFI,
|
||||
@@ -840,6 +959,18 @@ unsafe extern "C" {
|
||||
ksks: *const *mut ffi::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_cmux_ciphertext_64_ks32(
|
||||
streams: CudaStreamsFFI,
|
||||
lwe_array_out: *mut CudaRadixCiphertextFFI,
|
||||
lwe_condition: *const CudaRadixCiphertextFFI,
|
||||
lwe_array_true: *const CudaRadixCiphertextFFI,
|
||||
lwe_array_false: *const CudaRadixCiphertextFFI,
|
||||
mem_ptr: *mut i8,
|
||||
bsks: *const *mut ffi::c_void,
|
||||
ksks: *const *mut ffi::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cleanup_cuda_cmux(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
|
||||
}
|
||||
@@ -1016,6 +1147,28 @@ unsafe extern "C" {
|
||||
noise_reduction_type: PBS_MS_REDUCTION_T,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn scratch_cuda_integer_overflowing_sub_64_ks32_inplace(
|
||||
streams: CudaStreamsFFI,
|
||||
mem_ptr: *mut *mut i8,
|
||||
glwe_dimension: u32,
|
||||
polynomial_size: u32,
|
||||
big_lwe_dimension: u32,
|
||||
small_lwe_dimension: u32,
|
||||
ks_level: u32,
|
||||
ks_base_log: u32,
|
||||
pbs_level: u32,
|
||||
pbs_base_log: u32,
|
||||
grouping_factor: u32,
|
||||
num_blocks: u32,
|
||||
message_modulus: u32,
|
||||
carry_modulus: u32,
|
||||
pbs_type: PBS_TYPE,
|
||||
compute_overflow: u32,
|
||||
allocate_gpu_memory: bool,
|
||||
noise_reduction_type: PBS_MS_REDUCTION_T,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_integer_overflowing_sub_64_inplace(
|
||||
streams: CudaStreamsFFI,
|
||||
@@ -1030,6 +1183,20 @@ unsafe extern "C" {
|
||||
uses_input_borrow: u32,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_integer_overflowing_sub_64_ks32_inplace(
|
||||
streams: CudaStreamsFFI,
|
||||
lhs_array: *mut CudaRadixCiphertextFFI,
|
||||
rhs_array: *const CudaRadixCiphertextFFI,
|
||||
overflow_block: *mut CudaRadixCiphertextFFI,
|
||||
input_borrow: *const CudaRadixCiphertextFFI,
|
||||
mem_ptr: *mut i8,
|
||||
bsks: *const *mut ffi::c_void,
|
||||
ksks: *const *mut ffi::c_void,
|
||||
compute_overflow: u32,
|
||||
uses_input_borrow: u32,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cleanup_cuda_integer_overflowing_sub(
|
||||
streams: CudaStreamsFFI,
|
||||
@@ -1169,6 +1336,28 @@ unsafe extern "C" {
|
||||
noise_reduction_type: PBS_MS_REDUCTION_T,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn scratch_cuda_integer_div_rem_radix_ciphertext_64_ks32(
|
||||
streams: CudaStreamsFFI,
|
||||
is_signed: bool,
|
||||
mem_ptr: *mut *mut i8,
|
||||
glwe_dimension: u32,
|
||||
polynomial_size: u32,
|
||||
big_lwe_dimension: u32,
|
||||
small_lwe_dimension: u32,
|
||||
ks_level: u32,
|
||||
ks_base_log: u32,
|
||||
pbs_level: u32,
|
||||
pbs_base_log: u32,
|
||||
grouping_factor: u32,
|
||||
num_blocks: u32,
|
||||
message_modulus: u32,
|
||||
carry_modulus: u32,
|
||||
pbs_type: PBS_TYPE,
|
||||
allocate_gpu_memory: bool,
|
||||
noise_reduction_type: PBS_MS_REDUCTION_T,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_integer_div_rem_radix_ciphertext_64(
|
||||
streams: CudaStreamsFFI,
|
||||
@@ -1182,6 +1371,19 @@ unsafe extern "C" {
|
||||
ksks: *const *mut ffi::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_integer_div_rem_radix_ciphertext_64_ks32(
|
||||
streams: CudaStreamsFFI,
|
||||
quotient: *mut CudaRadixCiphertextFFI,
|
||||
remainder: *mut CudaRadixCiphertextFFI,
|
||||
numerator: *const CudaRadixCiphertextFFI,
|
||||
divisor: *const CudaRadixCiphertextFFI,
|
||||
is_signed: bool,
|
||||
mem_ptr: *mut i8,
|
||||
bsks: *const *mut ffi::c_void,
|
||||
ksks: *const *mut ffi::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cleanup_cuda_integer_div_rem(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
|
||||
}
|
||||
@@ -1826,6 +2028,30 @@ unsafe extern "C" {
|
||||
noise_reduction_type: PBS_MS_REDUCTION_T,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn scratch_cuda_cast_to_unsigned_64_ks32(
|
||||
streams: CudaStreamsFFI,
|
||||
mem_ptr: *mut *mut i8,
|
||||
glwe_dimension: u32,
|
||||
polynomial_size: u32,
|
||||
big_lwe_dimension: u32,
|
||||
small_lwe_dimension: u32,
|
||||
ks_level: u32,
|
||||
ks_base_log: u32,
|
||||
pbs_level: u32,
|
||||
pbs_base_log: u32,
|
||||
grouping_factor: u32,
|
||||
num_input_blocks: u32,
|
||||
target_num_blocks: u32,
|
||||
input_is_signed: bool,
|
||||
requires_full_propagate: bool,
|
||||
message_modulus: u32,
|
||||
carry_modulus: u32,
|
||||
pbs_type: PBS_TYPE,
|
||||
allocate_gpu_memory: bool,
|
||||
noise_reduction_type: PBS_MS_REDUCTION_T,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_cast_to_unsigned_64(
|
||||
streams: CudaStreamsFFI,
|
||||
@@ -1838,6 +2064,18 @@ unsafe extern "C" {
|
||||
ksks: *const *mut ffi::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_cast_to_unsigned_64_ks32(
|
||||
streams: CudaStreamsFFI,
|
||||
output: *mut CudaRadixCiphertextFFI,
|
||||
input: *mut CudaRadixCiphertextFFI,
|
||||
mem_ptr: *mut i8,
|
||||
target_num_blocks: u32,
|
||||
input_is_signed: bool,
|
||||
bsks: *const *mut ffi::c_void,
|
||||
ksks: *const *mut ffi::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cleanup_cuda_cast_to_unsigned_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
|
||||
}
|
||||
|
||||
@@ -866,6 +866,9 @@ fn main() {
|
||||
ParamType::Classical => {
|
||||
benchmark::params_aliases::BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into()
|
||||
},
|
||||
ParamType::ClassicalKs32 => {
|
||||
benchmark::params_aliases::BENCH_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M64.into()
|
||||
},
|
||||
_ => {
|
||||
benchmark::params_aliases::BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into()
|
||||
}
|
||||
|
||||
@@ -2719,7 +2719,13 @@ mod cuda {
|
||||
rng_func: default_scalar
|
||||
);
|
||||
|
||||
criterion_group!(cuda_ops_support_ks32, cuda_mul, cuda_add, cuda_sub,);
|
||||
criterion_group!(
|
||||
cuda_ops_support_ks32,
|
||||
cuda_mul,
|
||||
cuda_add,
|
||||
cuda_sub,
|
||||
cuda_div_rem,
|
||||
);
|
||||
|
||||
criterion_group!(
|
||||
unchecked_cuda_ops,
|
||||
|
||||
@@ -1492,35 +1492,70 @@ pub(crate) unsafe fn cuda_backend_unchecked_bitop_assign<T: UnsignedInteger, B:
|
||||
&mut radix_lwe_right_degrees,
|
||||
&mut radix_lwe_right_noise_levels,
|
||||
);
|
||||
scratch_cuda_bitop_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
op as u32,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
cuda_bitop_ciphertext_64(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_radix_lwe_left,
|
||||
&raw const cuda_ffi_radix_lwe_left,
|
||||
&raw const cuda_ffi_radix_lwe_right,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
|
||||
if TypeId::of::<T>() == TypeId::of::<u64>() {
|
||||
scratch_cuda_bitop_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
op as u32,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
cuda_bitop_ciphertext_64(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_radix_lwe_left,
|
||||
&raw const cuda_ffi_radix_lwe_left,
|
||||
&raw const cuda_ffi_radix_lwe_right,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
} else if TypeId::of::<T>() == TypeId::of::<u32>() {
|
||||
scratch_cuda_bitop_64_ks32(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
op as u32,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
cuda_bitop_ciphertext_64_ks32(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_radix_lwe_left,
|
||||
&raw const cuda_ffi_radix_lwe_left,
|
||||
&raw const cuda_ffi_radix_lwe_right,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
} else {
|
||||
panic!("Unsupported KS dtype");
|
||||
}
|
||||
cleanup_cuda_integer_bitop(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
|
||||
update_noise_degree(radix_lwe_left, &cuda_ffi_radix_lwe_left);
|
||||
}
|
||||
@@ -2093,37 +2128,73 @@ pub(crate) unsafe fn cuda_backend_unchecked_comparison<T: UnsignedInteger, B: Nu
|
||||
&mut radix_lwe_right_noise_levels,
|
||||
);
|
||||
|
||||
scratch_cuda_comparison_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
radix_lwe_left.d_blocks.lwe_ciphertext_count().0 as u32,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
op as u32,
|
||||
is_signed,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
if TypeId::of::<T>() == TypeId::of::<u64>() {
|
||||
scratch_cuda_comparison_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
radix_lwe_left.d_blocks.lwe_ciphertext_count().0 as u32,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
op as u32,
|
||||
is_signed,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
|
||||
cuda_comparison_ciphertext_64(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_radix_lwe_out,
|
||||
&raw const cuda_ffi_radix_lwe_left,
|
||||
&raw const cuda_ffi_radix_lwe_right,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
cuda_comparison_ciphertext_64(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_radix_lwe_out,
|
||||
&raw const cuda_ffi_radix_lwe_left,
|
||||
&raw const cuda_ffi_radix_lwe_right,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
} else if TypeId::of::<T>() == TypeId::of::<u32>() {
|
||||
scratch_cuda_comparison_64_ks32(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
radix_lwe_left.d_blocks.lwe_ciphertext_count().0 as u32,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
op as u32,
|
||||
is_signed,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
|
||||
cuda_comparison_ciphertext_64_ks32(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_radix_lwe_out,
|
||||
&raw const cuda_ffi_radix_lwe_left,
|
||||
&raw const cuda_ffi_radix_lwe_right,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
} else {
|
||||
panic!("Unsupported KS datatype");
|
||||
}
|
||||
|
||||
cleanup_cuda_integer_comparison(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
|
||||
update_noise_degree(radix_lwe_out, &cuda_ffi_radix_lwe_out);
|
||||
@@ -5565,35 +5636,71 @@ pub(crate) unsafe fn cuda_backend_unchecked_cmux<T: UnsignedInteger, B: Numeric>
|
||||
&mut condition_noise_levels,
|
||||
);
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
scratch_cuda_cmux_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
cuda_cmux_ciphertext_64(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_radix_lwe_out,
|
||||
&raw const cuda_ffi_condition,
|
||||
&raw const cuda_ffi_radix_lwe_true,
|
||||
&raw const cuda_ffi_radix_lwe_false,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
|
||||
if TypeId::of::<T>() == TypeId::of::<u64>() {
|
||||
scratch_cuda_cmux_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
cuda_cmux_ciphertext_64(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_radix_lwe_out,
|
||||
&raw const cuda_ffi_condition,
|
||||
&raw const cuda_ffi_radix_lwe_true,
|
||||
&raw const cuda_ffi_radix_lwe_false,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
} else if TypeId::of::<T>() == TypeId::of::<u32>() {
|
||||
scratch_cuda_cmux_64_ks32(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
cuda_cmux_ciphertext_64_ks32(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_radix_lwe_out,
|
||||
&raw const cuda_ffi_condition,
|
||||
&raw const cuda_ffi_radix_lwe_true,
|
||||
&raw const cuda_ffi_radix_lwe_false,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
} else {
|
||||
panic!("Unsupported KS dtype");
|
||||
}
|
||||
|
||||
cleanup_cuda_cmux(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
|
||||
update_noise_degree(radix_lwe_out, &cuda_ffi_radix_lwe_out);
|
||||
}
|
||||
@@ -6963,38 +7070,76 @@ pub(crate) unsafe fn cuda_backend_unchecked_unsigned_overflowing_sub_assign<
|
||||
.collect();
|
||||
let cuda_ffi_carry_in =
|
||||
prepare_cuda_radix_ffi(carry_in, &mut carry_in_degrees, &mut carry_in_noise_levels);
|
||||
scratch_cuda_integer_overflowing_sub_64_inplace(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension,
|
||||
lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
radix_lwe_left.d_blocks.lwe_ciphertext_count().0 as u32,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
compute_overflow as u32,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
cuda_integer_overflowing_sub_64_inplace(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_radix_lwe_left,
|
||||
&raw const cuda_ffi_radix_lwe_right,
|
||||
&raw mut cuda_ffi_carry_out,
|
||||
&raw const cuda_ffi_carry_in,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
compute_overflow as u32,
|
||||
uses_input_borrow,
|
||||
);
|
||||
|
||||
if TypeId::of::<T>() == TypeId::of::<u64>() {
|
||||
scratch_cuda_integer_overflowing_sub_64_inplace(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension,
|
||||
lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
radix_lwe_left.d_blocks.lwe_ciphertext_count().0 as u32,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
compute_overflow as u32,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
cuda_integer_overflowing_sub_64_inplace(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_radix_lwe_left,
|
||||
&raw const cuda_ffi_radix_lwe_right,
|
||||
&raw mut cuda_ffi_carry_out,
|
||||
&raw const cuda_ffi_carry_in,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
compute_overflow as u32,
|
||||
uses_input_borrow,
|
||||
);
|
||||
} else if TypeId::of::<T>() == TypeId::of::<u32>() {
|
||||
scratch_cuda_integer_overflowing_sub_64_ks32_inplace(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension,
|
||||
lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
radix_lwe_left.d_blocks.lwe_ciphertext_count().0 as u32,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
compute_overflow as u32,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
cuda_integer_overflowing_sub_64_ks32_inplace(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_radix_lwe_left,
|
||||
&raw const cuda_ffi_radix_lwe_right,
|
||||
&raw mut cuda_ffi_carry_out,
|
||||
&raw const cuda_ffi_carry_in,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
compute_overflow as u32,
|
||||
uses_input_borrow,
|
||||
);
|
||||
} else {
|
||||
panic!("Unsupported KS dtype");
|
||||
}
|
||||
cleanup_cuda_integer_overflowing_sub(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
|
||||
update_noise_degree(radix_lwe_left, &cuda_ffi_radix_lwe_left);
|
||||
update_noise_degree(carry_out, &cuda_ffi_carry_out);
|
||||
@@ -8291,34 +8436,68 @@ pub(crate) unsafe fn cuda_backend_boolean_bitnot_assign<T: UnsignedInteger, B: N
|
||||
&mut ciphertext_noise_levels,
|
||||
);
|
||||
|
||||
scratch_cuda_boolean_bitnot_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
1u32,
|
||||
is_unchecked,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
if TypeId::of::<T>() == TypeId::of::<u64>() {
|
||||
scratch_cuda_boolean_bitnot_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
1u32,
|
||||
is_unchecked,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
|
||||
cuda_boolean_bitnot_ciphertext_64(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_ciphertext,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
} else if TypeId::of::<T>() == TypeId::of::<u32>() {
|
||||
scratch_cuda_boolean_bitnot_64_ks32(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
1u32,
|
||||
is_unchecked,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
|
||||
cuda_boolean_bitnot_ciphertext_64_ks32(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_ciphertext,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
} else {
|
||||
panic!("Unsupported KS dype");
|
||||
}
|
||||
|
||||
cuda_boolean_bitnot_ciphertext_64(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_ciphertext,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
cleanup_cuda_boolean_bitnot(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
|
||||
update_noise_degree(ciphertext, &cuda_ffi_ciphertext);
|
||||
}
|
||||
@@ -8792,39 +8971,75 @@ pub(crate) unsafe fn cuda_backend_cast_to_unsigned<T: UnsignedInteger, B: Numeri
|
||||
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
|
||||
scratch_cuda_cast_to_unsigned_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_input_blocks,
|
||||
target_num_blocks,
|
||||
input_is_signed,
|
||||
requires_full_propagate,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
if TypeId::of::<T>() == TypeId::of::<u32>() {
|
||||
scratch_cuda_cast_to_unsigned_64_ks32(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_input_blocks,
|
||||
target_num_blocks,
|
||||
input_is_signed,
|
||||
requires_full_propagate,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
|
||||
cuda_cast_to_unsigned_64(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_output,
|
||||
&raw mut cuda_ffi_input,
|
||||
mem_ptr,
|
||||
target_num_blocks,
|
||||
input_is_signed,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
cuda_cast_to_unsigned_64_ks32(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_output,
|
||||
&raw mut cuda_ffi_input,
|
||||
mem_ptr,
|
||||
target_num_blocks,
|
||||
input_is_signed,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
} else if TypeId::of::<T>() == TypeId::of::<u64>() {
|
||||
scratch_cuda_cast_to_unsigned_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_input_blocks,
|
||||
target_num_blocks,
|
||||
input_is_signed,
|
||||
requires_full_propagate,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
|
||||
cuda_cast_to_unsigned_64(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_output,
|
||||
&raw mut cuda_ffi_input,
|
||||
mem_ptr,
|
||||
target_num_blocks,
|
||||
input_is_signed,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
}
|
||||
|
||||
cleanup_cuda_cast_to_unsigned_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
|
||||
|
||||
|
||||
@@ -408,57 +408,103 @@ impl CudaServerKey {
|
||||
is_unchecked: bool,
|
||||
streams: &CudaStreams,
|
||||
) {
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_boolean_bitnot_assign(
|
||||
streams,
|
||||
&mut ct.0.ciphertext as &mut CudaRadixCiphertext,
|
||||
is_unchecked,
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
d_bsk.output_lwe_dimension(),
|
||||
d_bsk.input_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
match &self.key_switching_key {
|
||||
CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_boolean_bitnot_assign(
|
||||
streams,
|
||||
&mut ct.0.ciphertext as &mut CudaRadixCiphertext,
|
||||
is_unchecked,
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
d_bsk.output_lwe_dimension(),
|
||||
d_bsk.input_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_boolean_bitnot_assign(
|
||||
streams,
|
||||
&mut ct.0.ciphertext as &mut CudaRadixCiphertext,
|
||||
is_unchecked,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
d_multibit_bsk.output_lwe_dimension(),
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_boolean_bitnot_assign(
|
||||
streams,
|
||||
&mut ct.0.ciphertext as &mut CudaRadixCiphertext,
|
||||
is_unchecked,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
d_multibit_bsk.output_lwe_dimension(),
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
},
|
||||
CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_boolean_bitnot_assign(
|
||||
streams,
|
||||
&mut ct.0.ciphertext as &mut CudaRadixCiphertext,
|
||||
is_unchecked,
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
d_bsk.output_lwe_dimension(),
|
||||
d_bsk.input_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_boolean_bitnot_assign(
|
||||
streams,
|
||||
&mut ct.0.ciphertext as &mut CudaRadixCiphertext,
|
||||
is_unchecked,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
d_multibit_bsk.output_lwe_dimension(),
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -534,61 +580,111 @@ impl CudaServerKey {
|
||||
|
||||
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_bitop_assign(
|
||||
streams,
|
||||
ct_left.as_mut(),
|
||||
ct_right.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
op,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
match &self.key_switching_key {
|
||||
CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_bitop_assign(
|
||||
streams,
|
||||
ct_left.as_mut(),
|
||||
ct_right.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
op,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_bitop_assign(
|
||||
streams,
|
||||
ct_left.as_mut(),
|
||||
ct_right.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
op,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_bitop_assign(
|
||||
streams,
|
||||
ct_left.as_mut(),
|
||||
ct_right.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
op,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
},
|
||||
CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_bitop_assign(
|
||||
streams,
|
||||
ct_left.as_mut(),
|
||||
ct_right.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
op,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_bitop_assign(
|
||||
streams,
|
||||
ct_left.as_mut(),
|
||||
ct_right.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
op,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -20,64 +20,117 @@ impl CudaServerKey {
|
||||
let mut result: T = self
|
||||
.create_trivial_zero_radix(true_ct.as_ref().d_blocks.lwe_ciphertext_count().0, stream);
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_cmux(
|
||||
stream,
|
||||
result.as_mut(),
|
||||
condition,
|
||||
true_ct.as_ref(),
|
||||
false_ct.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
match &self.key_switching_key {
|
||||
CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_cmux(
|
||||
stream,
|
||||
result.as_mut(),
|
||||
condition,
|
||||
true_ct.as_ref(),
|
||||
false_ct.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_cmux(
|
||||
stream,
|
||||
result.as_mut(),
|
||||
condition,
|
||||
true_ct.as_ref(),
|
||||
false_ct.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_cmux(
|
||||
stream,
|
||||
result.as_mut(),
|
||||
condition,
|
||||
true_ct.as_ref(),
|
||||
false_ct.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
},
|
||||
CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_cmux(
|
||||
stream,
|
||||
result.as_mut(),
|
||||
condition,
|
||||
true_ct.as_ref(),
|
||||
false_ct.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_cmux(
|
||||
stream,
|
||||
result.as_mut(),
|
||||
condition,
|
||||
true_ct.as_ref(),
|
||||
false_ct.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
|
||||
@@ -45,65 +45,117 @@ impl CudaServerKey {
|
||||
let mut result =
|
||||
CudaBooleanBlock::from_cuda_radix_ciphertext(CudaRadixCiphertext::new(block, ct_info));
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
match &self.key_switching_key {
|
||||
CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_comparison(
|
||||
streams,
|
||||
result.as_mut().as_mut(),
|
||||
ct_left.as_ref(),
|
||||
ct_right.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
op,
|
||||
T::IS_SIGNED,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_comparison(
|
||||
streams,
|
||||
result.as_mut().as_mut(),
|
||||
ct_left.as_ref(),
|
||||
ct_right.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
op,
|
||||
T::IS_SIGNED,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
},
|
||||
CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_comparison(
|
||||
streams,
|
||||
result.as_mut().as_mut(),
|
||||
ct_left.as_ref(),
|
||||
ct_right.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
op,
|
||||
T::IS_SIGNED,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_comparison(
|
||||
streams,
|
||||
result.as_mut().as_mut(),
|
||||
ct_left.as_ref(),
|
||||
ct_right.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
op,
|
||||
T::IS_SIGNED,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_comparison(
|
||||
streams,
|
||||
result.as_mut().as_mut(),
|
||||
ct_left.as_ref(),
|
||||
ct_right.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
op,
|
||||
T::IS_SIGNED,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_comparison(
|
||||
streams,
|
||||
result.as_mut().as_mut(),
|
||||
ct_left.as_ref(),
|
||||
ct_right.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
op,
|
||||
T::IS_SIGNED,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
|
||||
@@ -1019,61 +1019,109 @@ impl CudaServerKey {
|
||||
self.create_trivial_zero_radix(target_num_blocks, streams);
|
||||
|
||||
let requires_full_propagate = !source.block_carries_are_empty();
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_cast_to_unsigned(
|
||||
streams,
|
||||
result.as_mut(),
|
||||
source.as_mut(),
|
||||
T::IS_SIGNED,
|
||||
requires_full_propagate,
|
||||
target_num_blocks as u32,
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
match &self.key_switching_key {
|
||||
CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_cast_to_unsigned(
|
||||
streams,
|
||||
result.as_mut(),
|
||||
source.as_mut(),
|
||||
T::IS_SIGNED,
|
||||
requires_full_propagate,
|
||||
target_num_blocks as u32,
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_cast_to_unsigned(
|
||||
streams,
|
||||
result.as_mut(),
|
||||
source.as_mut(),
|
||||
T::IS_SIGNED,
|
||||
requires_full_propagate,
|
||||
target_num_blocks as u32,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_cast_to_unsigned(
|
||||
streams,
|
||||
result.as_mut(),
|
||||
source.as_mut(),
|
||||
T::IS_SIGNED,
|
||||
requires_full_propagate,
|
||||
target_num_blocks as u32,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
},
|
||||
CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_cast_to_unsigned(
|
||||
streams,
|
||||
result.as_mut(),
|
||||
source.as_mut(),
|
||||
T::IS_SIGNED,
|
||||
requires_full_propagate,
|
||||
target_num_blocks as u32,
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_cast_to_unsigned(
|
||||
streams,
|
||||
result.as_mut(),
|
||||
source.as_mut(),
|
||||
T::IS_SIGNED,
|
||||
requires_full_propagate,
|
||||
target_num_blocks as u32,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
|
||||
@@ -296,64 +296,118 @@ impl CudaServerKey {
|
||||
let aux_block: CudaUnsignedRadixCiphertext = self.create_trivial_zero_radix(1, stream);
|
||||
let in_carry_dvec =
|
||||
INPUT_BORROW.map_or_else(|| aux_block.as_ref(), |block| block.as_ref().as_ref());
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
|
||||
stream,
|
||||
ciphertext,
|
||||
rhs.as_ref(),
|
||||
overflow_block.as_mut(),
|
||||
in_carry_dvec,
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
ciphertext.info.blocks.first().unwrap().message_modulus,
|
||||
ciphertext.info.blocks.first().unwrap().carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
compute_overflow,
|
||||
uses_input_borrow,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
match &self.key_switching_key {
|
||||
CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
|
||||
stream,
|
||||
ciphertext,
|
||||
rhs.as_ref(),
|
||||
overflow_block.as_mut(),
|
||||
in_carry_dvec,
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
ciphertext.info.blocks.first().unwrap().message_modulus,
|
||||
ciphertext.info.blocks.first().unwrap().carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
compute_overflow,
|
||||
uses_input_borrow,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
|
||||
stream,
|
||||
ciphertext,
|
||||
rhs.as_ref(),
|
||||
overflow_block.as_mut(),
|
||||
in_carry_dvec,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
ciphertext.info.blocks.first().unwrap().message_modulus,
|
||||
ciphertext.info.blocks.first().unwrap().carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
compute_overflow,
|
||||
uses_input_borrow,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
|
||||
stream,
|
||||
ciphertext,
|
||||
rhs.as_ref(),
|
||||
overflow_block.as_mut(),
|
||||
in_carry_dvec,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
ciphertext.info.blocks.first().unwrap().message_modulus,
|
||||
ciphertext.info.blocks.first().unwrap().carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
compute_overflow,
|
||||
uses_input_borrow,
|
||||
None,
|
||||
);
|
||||
},
|
||||
CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
|
||||
stream,
|
||||
ciphertext,
|
||||
rhs.as_ref(),
|
||||
overflow_block.as_mut(),
|
||||
in_carry_dvec,
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
ciphertext.info.blocks.first().unwrap().message_modulus,
|
||||
ciphertext.info.blocks.first().unwrap().carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
compute_overflow,
|
||||
uses_input_borrow,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
|
||||
stream,
|
||||
ciphertext,
|
||||
rhs.as_ref(),
|
||||
overflow_block.as_mut(),
|
||||
in_carry_dvec,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
ciphertext.info.blocks.first().unwrap().message_modulus,
|
||||
ciphertext.info.blocks.first().unwrap().carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
compute_overflow,
|
||||
uses_input_borrow,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
let ct_overflowed = CudaBooleanBlock::from_cuda_radix_ciphertext(overflow_block.ciphertext);
|
||||
|
||||
(ct_res, ct_overflowed)
|
||||
|
||||
Reference in New Issue
Block a user