fix(gpu): erc20 classical ks32

This commit is contained in:
Andrei Stoian
2026-01-08 10:19:36 +01:00
parent da92c45fa6
commit b4d54b5ba2
26 changed files with 1755 additions and 556 deletions

View File

@@ -1044,7 +1044,7 @@ template <typename Torus, typename KSTorus> struct unsigned_int_div_rem_memory {
};
int_radix_lut<Torus, KSTorus> *luts[2] = {message_extract_lut_1,
message_extract_lut_2};
message_extract_lut_2};
auto active_streams =
streams.active_gpu_subset(num_blocks, params.pbs_type);
for (int j = 0; j < 2; j++) {
@@ -1134,7 +1134,8 @@ template <typename Torus, typename KSTorus> struct unsigned_int_div_rem_memory {
zero_out_if_overflow_happened[1]->broadcast_lut(active_streams);
// merge_overflow_flags_luts
merge_overflow_flags_luts = new int_radix_lut<Torus, KSTorus> *[num_bits_in_message];
merge_overflow_flags_luts =
new int_radix_lut<Torus, KSTorus> *[num_bits_in_message];
auto active_gpu_count_for_bits =
streams.active_gpu_subset(1, params.pbs_type);
for (int i = 0; i < num_bits_in_message; i++) {

View File

@@ -274,6 +274,16 @@ uint64_t scratch_cuda_comparison_64(
bool is_signed, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type);
uint64_t scratch_cuda_comparison_64_ks32(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t lwe_ciphertext_count, uint32_t message_modulus,
uint32_t carry_modulus, PBS_TYPE pbs_type, COMPARISON_TYPE op_type,
bool is_signed, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type);
void cuda_comparison_ciphertext_64(CudaStreamsFFI streams,
CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_1,
@@ -281,6 +291,12 @@ void cuda_comparison_ciphertext_64(CudaStreamsFFI streams,
int8_t *mem_ptr, void *const *bsks,
void *const *ksks);
void cuda_comparison_ciphertext_64_ks32(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_1,
CudaRadixCiphertextFFI const *lwe_array_2, int8_t *mem_ptr,
void *const *bsks, void *const *ksks);
void cuda_scalar_comparison_ciphertext_64(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_in, void const *scalar_blocks,
@@ -323,6 +339,20 @@ void cuda_boolean_bitnot_ciphertext_64(CudaStreamsFFI streams,
int8_t *mem_ptr, void *const *bsks,
void *const *ksks);
uint64_t scratch_cuda_boolean_bitnot_64_ks32(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
uint32_t lwe_ciphertext_count, bool is_unchecked, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type);
void cuda_boolean_bitnot_ciphertext_64_ks32(CudaStreamsFFI streams,
CudaRadixCiphertextFFI *lwe_array,
int8_t *mem_ptr, void *const *bsks,
void *const *ksks);
void cleanup_cuda_boolean_bitnot(CudaStreamsFFI streams, int8_t **mem_ptr_void);
void cuda_bitnot_ciphertext_64(CudaStreamsFFI streams,
@@ -340,6 +370,15 @@ uint64_t scratch_cuda_bitop_64(
uint32_t carry_modulus, PBS_TYPE pbs_type, BITOP_TYPE op_type,
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type);
uint64_t scratch_cuda_bitop_64_ks32(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t lwe_ciphertext_count, uint32_t message_modulus,
uint32_t carry_modulus, PBS_TYPE pbs_type, BITOP_TYPE op_type,
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type);
void cuda_scalar_bitop_ciphertext_64(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_input, void const *clear_blocks,
@@ -353,6 +392,13 @@ void cuda_bitop_ciphertext_64(CudaStreamsFFI streams,
int8_t *mem_ptr, void *const *bsks,
void *const *ksks);
void cuda_bitop_ciphertext_64_ks32(CudaStreamsFFI streams,
CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_1,
CudaRadixCiphertextFFI const *lwe_array_2,
int8_t *mem_ptr, void *const *bsks,
void *const *ksks);
void cleanup_cuda_integer_bitop(CudaStreamsFFI streams, int8_t **mem_ptr_void);
uint64_t scratch_cuda_cmux_64(CudaStreamsFFI streams, int8_t **mem_ptr,
@@ -366,6 +412,15 @@ uint64_t scratch_cuda_cmux_64(CudaStreamsFFI streams, int8_t **mem_ptr,
PBS_TYPE pbs_type, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type);
uint64_t scratch_cuda_cmux_64_ks32(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t lwe_ciphertext_count, uint32_t message_modulus,
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type);
void cuda_cmux_ciphertext_64(CudaStreamsFFI streams,
CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_condition,
@@ -374,6 +429,14 @@ void cuda_cmux_ciphertext_64(CudaStreamsFFI streams,
int8_t *mem_ptr, void *const *bsks,
void *const *ksks);
void cuda_cmux_ciphertext_64_ks32(CudaStreamsFFI streams,
CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_condition,
CudaRadixCiphertextFFI const *lwe_array_true,
CudaRadixCiphertextFFI const *lwe_array_false,
int8_t *mem_ptr, void *const *bsks,
void *const *ksks);
void cleanup_cuda_cmux(CudaStreamsFFI streams, int8_t **mem_ptr_void);
uint64_t scratch_cuda_scalar_rotate_64(
@@ -452,6 +515,15 @@ uint64_t scratch_cuda_integer_overflowing_sub_64_inplace(
PBS_TYPE pbs_type, uint32_t compute_overflow, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type);
uint64_t scratch_cuda_integer_overflowing_sub_64_ks32_inplace(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t num_blocks, uint32_t message_modulus, uint32_t carry_modulus,
PBS_TYPE pbs_type, uint32_t compute_overflow, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type);
void cuda_integer_overflowing_sub_64_inplace(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lhs_array,
const CudaRadixCiphertextFFI *rhs_array,
@@ -460,6 +532,14 @@ void cuda_integer_overflowing_sub_64_inplace(
void *const *bsks, void *const *ksks, uint32_t compute_overflow,
uint32_t uses_input_borrow);
void cuda_integer_overflowing_sub_64_ks32_inplace(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lhs_array,
const CudaRadixCiphertextFFI *rhs_array,
CudaRadixCiphertextFFI *overflow_block,
const CudaRadixCiphertextFFI *input_borrow, int8_t *mem_ptr,
void *const *bsks, void *const *ksks, uint32_t compute_overflow,
uint32_t uses_input_borrow);
void cleanup_cuda_integer_overflowing_sub(CudaStreamsFFI streams,
int8_t **mem_ptr_void);
@@ -844,12 +924,29 @@ uint64_t scratch_cuda_cast_to_unsigned_64(
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type);
uint64_t scratch_cuda_cast_to_unsigned_64_ks32(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t num_input_blocks, uint32_t target_num_blocks, bool input_is_signed,
bool requires_full_propagate, uint32_t message_modulus,
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type);
void cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
CudaRadixCiphertextFFI *output,
CudaRadixCiphertextFFI *input, int8_t *mem_ptr,
uint32_t target_num_blocks, bool input_is_signed,
void *const *bsks, void *const *ksks);
void cuda_cast_to_unsigned_64_ks32(CudaStreamsFFI streams,
CudaRadixCiphertextFFI *output,
CudaRadixCiphertextFFI *input,
int8_t *mem_ptr, uint32_t target_num_blocks,
bool input_is_signed, void *const *bsks,
void *const *ksks);
void cleanup_cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
int8_t **mem_ptr_void);

View File

@@ -34,7 +34,7 @@ __host__ void vectorized_aes_256_encrypt_inplace(
CudaStreams streams, CudaRadixCiphertextFFI *all_states_bitsliced,
CudaRadixCiphertextFFI const *round_keys, uint32_t num_aes_inputs,
int_aes_encrypt_buffer<Torus, KSTorus> *mem, void *const *bsks,
Torus *const *ksks) {
KSTorus *const *ksks) {
constexpr uint32_t BITS_PER_BYTE = 8;
constexpr uint32_t STATE_BYTES = 16;
@@ -186,7 +186,7 @@ __host__ void host_integer_aes_ctr_256_encrypt(
CudaRadixCiphertextFFI const *iv, CudaRadixCiphertextFFI const *round_keys,
const Torus *counter_bits_le_all_blocks, uint32_t num_aes_inputs,
int_aes_encrypt_buffer<Torus, KSTorus> *mem, void *const *bsks,
Torus *const *ksks) {
KSTorus *const *ksks) {
constexpr uint32_t NUM_BITS = 128;
@@ -245,7 +245,7 @@ __host__ void host_integer_key_expansion_256(
CudaStreams streams, CudaRadixCiphertextFFI *expanded_keys,
CudaRadixCiphertextFFI const *key,
int_key_expansion_256_buffer<Torus, KSTorus> *mem, void *const *bsks,
Torus *const *ksks) {
KSTorus *const *ksks) {
constexpr uint32_t BITS_PER_WORD = 32;
constexpr uint32_t BITS_PER_BYTE = 8;

View File

@@ -23,10 +23,10 @@ __host__ uint64_t scratch_cuda_integer_abs(
return size_tracker;
}
template <typename Torus>
template <typename Torus, typename KSTorus>
__host__ void host_integer_abs(CudaStreams streams, CudaRadixCiphertextFFI *ct,
void *const *bsks, uint64_t *const *ksks,
int_abs_buffer<uint64_t, uint64_t> *mem_ptr,
void *const *bsks, KSTorus *const *ksks,
int_abs_buffer<Torus, KSTorus> *mem_ptr,
bool is_signed) {
if (!is_signed)
return;
@@ -40,7 +40,7 @@ __host__ void host_integer_abs(CudaStreams streams, CudaRadixCiphertextFFI *ct,
copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
mask, ct);
host_arithmetic_scalar_shift_inplace<Torus>(
host_arithmetic_scalar_shift_inplace<Torus, KSTorus>(
streams, mask, num_bits_in_ciphertext - 1,
mem_ptr->arithmetic_scalar_shift_mem, bsks, ksks);
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), ct, mask, ct,
@@ -49,11 +49,12 @@ __host__ void host_integer_abs(CudaStreams streams, CudaRadixCiphertextFFI *ct,
uint32_t requested_flag = outputFlag::FLAG_NONE;
uint32_t uses_carry = 0;
host_propagate_single_carry<Torus>(streams, ct, nullptr, nullptr,
mem_ptr->scp_mem, bsks, ksks,
requested_flag, uses_carry);
host_propagate_single_carry<Torus, KSTorus>(streams, ct, nullptr, nullptr,
mem_ptr->scp_mem, bsks, ksks,
requested_flag, uses_carry);
host_bitop<Torus>(streams, ct, mask, ct, mem_ptr->bitxor_mem, bsks, ksks);
host_bitop<Torus, KSTorus>(streams, ct, mask, ct, mem_ptr->bitxor_mem, bsks,
ksks);
}
#endif // TFHE_RS_ABS_CUH

View File

@@ -63,6 +63,26 @@ uint64_t scratch_cuda_boolean_bitnot_64(
lwe_ciphertext_count, is_unchecked, allocate_gpu_memory);
}
// ks32 flavour of scratch_cuda_boolean_bitnot_64: prepares the boolean-bitnot
// working buffer for a 64-bit torus whose keyswitch keys are 32-bit words.
// Returns the size reported by the underlying scratch allocator.
uint64_t scratch_cuda_boolean_bitnot_64_ks32(
    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t big_lwe_dimension,
    uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
    uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
    uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
    uint32_t lwe_ciphertext_count, bool is_unchecked, bool allocate_gpu_memory,
    PBS_MS_REDUCTION_T noise_reduction_type) {
  // Fold the individual FFI scalars into the internal radix parameter struct.
  const int_radix_params params(
      pbs_type, glwe_dimension, polynomial_size, big_lwe_dimension,
      small_lwe_dimension, ks_level, ks_base_log, pbs_level, pbs_base_log,
      grouping_factor, message_modulus, carry_modulus, noise_reduction_type);
  // The opaque int8_t** handle actually points at the typed buffer.
  auto **buffer = (boolean_bitnot_buffer<uint64_t, uint32_t> **)mem_ptr;
  return scratch_cuda_boolean_bitnot<uint64_t, uint32_t>(
      CudaStreams(streams), buffer, params, lwe_ciphertext_count, is_unchecked,
      allocate_gpu_memory);
}
void cuda_boolean_bitnot_ciphertext_64(CudaStreamsFFI streams,
CudaRadixCiphertextFFI *lwe_array,
int8_t *mem_ptr, void *const *bsks,
@@ -73,6 +93,16 @@ void cuda_boolean_bitnot_ciphertext_64(CudaStreamsFFI streams,
(uint64_t **)(ksks));
}
// Applies a boolean bitnot to `lwe_array` (ks32 flavour: the keyswitch keys
// `ksks` are reinterpreted as 32-bit words rather than 64-bit).
void cuda_boolean_bitnot_ciphertext_64_ks32(CudaStreamsFFI streams,
                                            CudaRadixCiphertextFFI *lwe_array,
                                            int8_t *mem_ptr, void *const *bsks,
                                            void *const *ksks) {
  // `mem_ptr` is the scratch buffer allocated by the matching _ks32 scratch.
  auto *buffer = (boolean_bitnot_buffer<uint64_t, uint32_t> *)mem_ptr;
  host_boolean_bitnot<uint64_t, uint32_t>(CudaStreams(streams), lwe_array,
                                          buffer, bsks, (uint32_t **)(ksks));
}
void cleanup_cuda_boolean_bitnot(CudaStreamsFFI streams,
int8_t **mem_ptr_void) {
@@ -102,6 +132,25 @@ uint64_t scratch_cuda_bitop_64(
lwe_ciphertext_count, params, op_type, allocate_gpu_memory);
}
// ks32 flavour of scratch_cuda_bitop_64: allocates the bitop working buffer
// for a 64-bit torus with 32-bit keyswitch keys and returns the size
// reported by the underlying scratch allocator.
uint64_t scratch_cuda_bitop_64_ks32(
    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t big_lwe_dimension,
    uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
    uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
    uint32_t lwe_ciphertext_count, uint32_t message_modulus,
    uint32_t carry_modulus, PBS_TYPE pbs_type, BITOP_TYPE op_type,
    bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type) {
  // Fold the FFI scalars into the internal radix parameter struct.
  const int_radix_params params(
      pbs_type, glwe_dimension, polynomial_size, big_lwe_dimension,
      small_lwe_dimension, ks_level, ks_base_log, pbs_level, pbs_base_log,
      grouping_factor, message_modulus, carry_modulus, noise_reduction_type);
  auto **buffer = (int_bitop_buffer<uint64_t, uint32_t> **)mem_ptr;
  return scratch_cuda_bitop<uint64_t, uint32_t>(CudaStreams(streams), buffer,
                                                lwe_ciphertext_count, params,
                                                op_type, allocate_gpu_memory);
}
void cuda_bitnot_ciphertext_64(CudaStreamsFFI streams,
CudaRadixCiphertextFFI *radix_ciphertext,
uint32_t ct_message_modulus,
@@ -126,6 +175,19 @@ void cuda_bitop_ciphertext_64(CudaStreamsFFI streams,
(uint64_t **)(ksks));
}
// Applies a binary bitop (op selected at scratch time) on two radix
// ciphertexts, writing into `lwe_array_out` (ks32: 32-bit keyswitch keys).
void cuda_bitop_ciphertext_64_ks32(CudaStreamsFFI streams,
                                   CudaRadixCiphertextFFI *lwe_array_out,
                                   CudaRadixCiphertextFFI const *lwe_array_1,
                                   CudaRadixCiphertextFFI const *lwe_array_2,
                                   int8_t *mem_ptr, void *const *bsks,
                                   void *const *ksks) {
  // Consistency fix: previously `host_bitop<uint64_t>` left the KSTorus
  // template argument to deduction from the casted `ksks`. Spell out
  // <uint64_t, uint32_t> like every other ks32 entry point in this commit
  // (host_cmux, host_integer_div_rem, ...), so the intent is explicit and
  // does not silently change if the `ksks` cast is edited.
  host_bitop<uint64_t, uint32_t>(
      CudaStreams(streams), lwe_array_out, lwe_array_1, lwe_array_2,
      (int_bitop_buffer<uint64_t, uint32_t> *)mem_ptr, bsks,
      (uint32_t **)(ksks));
}
void cleanup_cuda_integer_bitop(CudaStreamsFFI streams, int8_t **mem_ptr_void) {
int_bitop_buffer<uint64_t, uint64_t> *mem_ptr =

View File

@@ -49,6 +49,28 @@ uint64_t scratch_cuda_cast_to_unsigned_64(
requires_full_propagate, allocate_gpu_memory);
}
// ks32 flavour of scratch_cuda_cast_to_unsigned_64: allocates the working
// buffer for casting a (possibly signed) radix integer to an unsigned one
// with `target_num_blocks` blocks. Returns the tracked allocation size.
uint64_t scratch_cuda_cast_to_unsigned_64_ks32(
    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t big_lwe_dimension,
    uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
    uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
    uint32_t num_input_blocks, uint32_t target_num_blocks, bool input_is_signed,
    bool requires_full_propagate, uint32_t message_modulus,
    uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
    PBS_MS_REDUCTION_T noise_reduction_type) {
  // Fold the FFI scalars into the internal radix parameter struct.
  const int_radix_params params(
      pbs_type, glwe_dimension, polynomial_size, big_lwe_dimension,
      small_lwe_dimension, ks_level, ks_base_log, pbs_level, pbs_base_log,
      grouping_factor, message_modulus, carry_modulus, noise_reduction_type);
  auto **buffer = (int_cast_to_unsigned_buffer<uint64_t, uint32_t> **)mem_ptr;
  return scratch_cuda_cast_to_unsigned<uint64_t, uint32_t>(
      CudaStreams(streams), buffer, params, num_input_blocks, target_num_blocks,
      input_is_signed, requires_full_propagate, allocate_gpu_memory);
}
void cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
CudaRadixCiphertextFFI *output,
CudaRadixCiphertextFFI *input, int8_t *mem_ptr,
@@ -61,6 +83,19 @@ void cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
target_num_blocks, input_is_signed, bsks, (uint64_t **)ksks);
}
// Casts `input` to an unsigned radix ciphertext of `target_num_blocks`
// blocks, writing the result into `output` (ks32: 32-bit keyswitch keys).
void cuda_cast_to_unsigned_64_ks32(CudaStreamsFFI streams,
                                   CudaRadixCiphertextFFI *output,
                                   CudaRadixCiphertextFFI *input,
                                   int8_t *mem_ptr, uint32_t target_num_blocks,
                                   bool input_is_signed, void *const *bsks,
                                   void *const *ksks) {
  // `mem_ptr` is the buffer allocated by scratch_cuda_cast_to_unsigned_64_ks32.
  auto *buffer = (int_cast_to_unsigned_buffer<uint64_t, uint32_t> *)mem_ptr;
  host_cast_to_unsigned<uint64_t, uint32_t>(CudaStreams(streams), output, input,
                                            buffer, target_num_blocks,
                                            input_is_signed, bsks,
                                            (uint32_t **)ksks);
}
void cleanup_cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
int8_t **mem_ptr_void) {
int_cast_to_unsigned_buffer<uint64_t, uint64_t> *mem_ptr =

View File

@@ -130,7 +130,7 @@ host_cast_to_unsigned(CudaStreams streams, CudaRadixCiphertextFFI *output,
CudaRadixCiphertextFFI *input,
int_cast_to_unsigned_buffer<Torus, KSTorus> *mem_ptr,
uint32_t target_num_blocks, bool input_is_signed,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
uint32_t current_num_blocks = input->num_radix_blocks;
@@ -143,9 +143,9 @@ host_cast_to_unsigned(CudaStreams streams, CudaRadixCiphertextFFI *output,
uint32_t num_blocks_to_add = target_num_blocks - current_num_blocks;
if (input_is_signed) {
host_extend_radix_with_sign_msb<Torus>(
host_extend_radix_with_sign_msb<Torus, KSTorus>(
streams, output, input, mem_ptr->extend_buffer, num_blocks_to_add,
bsks, (Torus **)ksks);
bsks, ksks);
} else {
host_extend_radix_with_trivial_zero_blocks_msb<Torus>(output, input,
streams);

View File

@@ -26,6 +26,30 @@ uint64_t scratch_cuda_cmux_64(CudaStreamsFFI streams, int8_t **mem_ptr,
return ret;
}
// ks32 flavour of scratch_cuda_cmux_64: allocates the cmux working buffer
// for a 64-bit torus with 32-bit keyswitch keys. Returns the size reported
// by the underlying scratch allocator.
uint64_t scratch_cuda_cmux_64_ks32(
    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t big_lwe_dimension,
    uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
    uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
    uint32_t lwe_ciphertext_count, uint32_t message_modulus,
    uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
    PBS_MS_REDUCTION_T noise_reduction_type) {
  PUSH_RANGE("scratch cmux")
  // Fold the FFI scalars into the internal radix parameter struct.
  int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
                          big_lwe_dimension, small_lwe_dimension, ks_level,
                          ks_base_log, pbs_level, pbs_base_log, grouping_factor,
                          message_modulus, carry_modulus, noise_reduction_type);
  // LUT predicate used to turn the condition block into a 0/1 selector:
  // maps an input to 1 exactly when it equals 1, 0 otherwise.
  std::function<uint64_t(uint64_t)> predicate_lut_f =
      [](uint64_t x) -> uint64_t { return x == 1; };
  // Keep the result in a local so the NVTX range is popped before returning.
  uint64_t ret = scratch_cuda_cmux<uint64_t, uint32_t>(
      CudaStreams(streams), (int_cmux_buffer<uint64_t, uint32_t> **)mem_ptr,
      predicate_lut_f, lwe_ciphertext_count, params, allocate_gpu_memory);
  POP_RANGE()
  return ret;
}
void cuda_cmux_ciphertext_64(CudaStreamsFFI streams,
CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_condition,
@@ -34,10 +58,25 @@ void cuda_cmux_ciphertext_64(CudaStreamsFFI streams,
int8_t *mem_ptr, void *const *bsks,
void *const *ksks) {
PUSH_RANGE("cmux")
host_cmux<uint64_t>(CudaStreams(streams), lwe_array_out, lwe_condition,
lwe_array_true, lwe_array_false,
(int_cmux_buffer<uint64_t, uint64_t> *)mem_ptr, bsks,
(uint64_t **)(ksks));
host_cmux<uint64_t, uint64_t>(CudaStreams(streams), lwe_array_out,
lwe_condition, lwe_array_true, lwe_array_false,
(int_cmux_buffer<uint64_t, uint64_t> *)mem_ptr,
bsks, (uint64_t **)(ksks));
POP_RANGE()
}
// Selects between `lwe_array_true` and `lwe_array_false` based on
// `lwe_condition`, writing into `lwe_array_out` (ks32: 32-bit keyswitch keys).
void cuda_cmux_ciphertext_64_ks32(CudaStreamsFFI streams,
                                  CudaRadixCiphertextFFI *lwe_array_out,
                                  CudaRadixCiphertextFFI const *lwe_condition,
                                  CudaRadixCiphertextFFI const *lwe_array_true,
                                  CudaRadixCiphertextFFI const *lwe_array_false,
                                  int8_t *mem_ptr, void *const *bsks,
                                  void *const *ksks) {
  PUSH_RANGE("cmux")
  // `mem_ptr` is the buffer allocated by scratch_cuda_cmux_64_ks32.
  auto *buffer = (int_cmux_buffer<uint64_t, uint32_t> *)mem_ptr;
  host_cmux<uint64_t, uint32_t>(CudaStreams(streams), lwe_array_out,
                                lwe_condition, lwe_array_true, lwe_array_false,
                                buffer, bsks, (uint32_t **)(ksks));
  POP_RANGE()
}

View File

@@ -39,6 +39,45 @@ uint64_t scratch_cuda_comparison_64(
return size_tracker;
}
// ks32 flavour of scratch_cuda_comparison_64: allocates the comparison
// working buffer for a 64-bit torus with 32-bit keyswitch keys. The buffer
// layout depends on `op_type`; equality checks never use signedness.
// Returns the accumulated tracked allocation size.
uint64_t scratch_cuda_comparison_64_ks32(
    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t big_lwe_dimension,
    uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
    uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
    uint32_t num_radix_blocks, uint32_t message_modulus, uint32_t carry_modulus,
    PBS_TYPE pbs_type, COMPARISON_TYPE op_type, bool is_signed,
    bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type) {
  PUSH_RANGE("scratch comparison")
  // Fold the FFI scalars into the internal radix parameter struct.
  int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
                          big_lwe_dimension, small_lwe_dimension, ks_level,
                          ks_base_log, pbs_level, pbs_base_log, grouping_factor,
                          message_modulus, carry_modulus, noise_reduction_type);
  uint64_t size_tracker = 0;
  switch (op_type) {
  case EQ:
  case NE:
    // Equality/inequality ignore the sign: `is_signed` is forced to false.
    size_tracker += scratch_cuda_comparison_check<uint64_t, uint32_t>(
        CudaStreams(streams),
        (int_comparison_buffer<uint64_t, uint32_t> **)mem_ptr, num_radix_blocks,
        params, op_type, false, allocate_gpu_memory);
    break;
  case GT:
  case GE:
  case LT:
  case LE:
  case MAX:
  case MIN:
    // Ordering and max/min honour the caller-provided signedness.
    size_tracker += scratch_cuda_comparison_check<uint64_t, uint32_t>(
        CudaStreams(streams),
        (int_comparison_buffer<uint64_t, uint32_t> **)mem_ptr, num_radix_blocks,
        params, op_type, is_signed, allocate_gpu_memory);
    break;
  }
  // NOTE(review): no default case — for any other COMPARISON_TYPE value the
  // buffer is left unallocated and 0 is returned; confirm this is intended.
  POP_RANGE()
  return size_tracker;
}
void cuda_comparison_ciphertext_64(CudaStreamsFFI streams,
CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_1,
@@ -87,6 +126,53 @@ void cuda_comparison_ciphertext_64(CudaStreamsFFI streams,
POP_RANGE()
}
// Runs the comparison selected at scratch time (stored in the buffer's `op`)
// on two radix ciphertexts, writing into `lwe_array_out`.
// ks32 flavour: keyswitch keys are reinterpreted as 32-bit words.
void cuda_comparison_ciphertext_64_ks32(
    CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out,
    CudaRadixCiphertextFFI const *lwe_array_1,
    CudaRadixCiphertextFFI const *lwe_array_2, int8_t *mem_ptr,
    void *const *bsks, void *const *ksks) {
  PUSH_RANGE("comparison")
  if (lwe_array_1->num_radix_blocks != lwe_array_2->num_radix_blocks)
    PANIC("Cuda error: input num radix blocks must be the same")
  // The output ciphertext might be a boolean block or a radix ciphertext
  // depending on the case (eq/gt vs max/min) so the amount of blocks to
  // consider for calculation is the one of the input
  auto num_radix_blocks = lwe_array_1->num_radix_blocks;
  int_comparison_buffer<uint64_t, uint32_t> *buffer =
      (int_comparison_buffer<uint64_t, uint32_t> *)mem_ptr;
  // Consistency fix: spell out <uint64_t, uint32_t> on the host_* calls like
  // the other ks32 entry points, instead of leaving KSTorus to be deduced
  // from the `(uint32_t **)` cast on `ksks`.
  switch (buffer->op) {
  case EQ:
  case NE:
    host_equality_check<uint64_t, uint32_t>(
        CudaStreams(streams), lwe_array_out, lwe_array_1, lwe_array_2, buffer,
        bsks, (uint32_t **)(ksks), num_radix_blocks);
    break;
  case GT:
  case GE:
  case LT:
  case LE:
    if (num_radix_blocks % 2 != 0)
      PANIC("Cuda error (comparisons): the number of radix blocks has to be "
            "even.")
    host_difference_check<uint64_t, uint32_t>(
        CudaStreams(streams), lwe_array_out, lwe_array_1, lwe_array_2, buffer,
        buffer->diff_buffer->operator_f, bsks, (uint32_t **)(ksks),
        num_radix_blocks);
    break;
  case MAX:
  case MIN:
    if (num_radix_blocks % 2 != 0)
      PANIC("Cuda error (max/min): the number of radix blocks has to be even.")
    host_maxmin<uint64_t, uint32_t>(CudaStreams(streams), lwe_array_out,
                                    lwe_array_1, lwe_array_2, buffer, bsks,
                                    (uint32_t **)(ksks), num_radix_blocks);
    break;
  default:
    PANIC("Cuda error: integer operation not supported")
  }
  POP_RANGE()
}
void cleanup_cuda_integer_comparison(CudaStreamsFFI streams,
int8_t **mem_ptr_void) {
PUSH_RANGE("cleanup comparison")

View File

@@ -21,6 +21,27 @@ uint64_t scratch_cuda_integer_div_rem_radix_ciphertext_64(
POP_RANGE()
}
// ks32 flavour of scratch_cuda_integer_div_rem_radix_ciphertext_64:
// allocates the div/rem working buffer for a 64-bit torus with 32-bit
// keyswitch keys. Returns the size reported by the scratch allocator.
uint64_t scratch_cuda_integer_div_rem_radix_ciphertext_64_ks32(
    CudaStreamsFFI streams, bool is_signed, int8_t **mem_ptr,
    uint32_t glwe_dimension, uint32_t polynomial_size,
    uint32_t big_lwe_dimension, uint32_t small_lwe_dimension, uint32_t ks_level,
    uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
    uint32_t grouping_factor, uint32_t num_blocks, uint32_t message_modulus,
    uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
    PBS_MS_REDUCTION_T noise_reduction_type) {
  PUSH_RANGE("scratch div")
  // Fold the FFI scalars into the internal radix parameter struct.
  int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
                          big_lwe_dimension, small_lwe_dimension, ks_level,
                          ks_base_log, pbs_level, pbs_base_log, grouping_factor,
                          message_modulus, carry_modulus, noise_reduction_type);
  // Bug fix: POP_RANGE() was previously placed AFTER the return statement,
  // making it unreachable, so the "scratch div" NVTX range was never popped.
  // Capture the result first, pop the range, then return (same pattern as
  // scratch_cuda_cmux_64_ks32).
  uint64_t size_tracker = scratch_cuda_integer_div_rem<uint64_t, uint32_t>(
      CudaStreams(streams), is_signed,
      (int_div_rem_memory<uint64_t, uint32_t> **)mem_ptr, num_blocks, params,
      allocate_gpu_memory);
  POP_RANGE()
  return size_tracker;
}
void cuda_integer_div_rem_radix_ciphertext_64(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *quotient,
CudaRadixCiphertextFFI *remainder, CudaRadixCiphertextFFI const *numerator,
@@ -35,6 +56,20 @@ void cuda_integer_div_rem_radix_ciphertext_64(
POP_RANGE()
}
// Computes quotient and remainder of `numerator` / `divisor` into the two
// output ciphertexts (ks32: 32-bit keyswitch keys). Sign handling follows
// `is_signed`.
void cuda_integer_div_rem_radix_ciphertext_64_ks32(
    CudaStreamsFFI streams, CudaRadixCiphertextFFI *quotient,
    CudaRadixCiphertextFFI *remainder, CudaRadixCiphertextFFI const *numerator,
    CudaRadixCiphertextFFI const *divisor, bool is_signed, int8_t *mem_ptr,
    void *const *bsks, void *const *ksks) {
  PUSH_RANGE("div")
  // `mem_ptr` is the buffer allocated by the matching _ks32 scratch call.
  host_integer_div_rem<uint64_t, uint32_t>(
      CudaStreams(streams), quotient, remainder, numerator, divisor, is_signed,
      bsks, (uint32_t **)(ksks),
      (int_div_rem_memory<uint64_t, uint32_t> *)mem_ptr);
  POP_RANGE()
}
void cleanup_cuda_integer_div_rem(CudaStreamsFFI streams,
int8_t **mem_ptr_void) {
PUSH_RANGE("cleanup div")

View File

@@ -31,8 +31,8 @@ __host__ void host_unsigned_integer_div_rem_block_by_block_2_2(
CudaStreams streams, CudaRadixCiphertextFFI *quotient,
CudaRadixCiphertextFFI *remainder, CudaRadixCiphertextFFI const *numerator,
CudaRadixCiphertextFFI const *divisor, void *const *bsks,
uint64_t *const *ksks,
unsigned_int_div_rem_2_2_memory<uint64_t, uint64_t> *mem_ptr) {
KSTorus *const *ksks,
unsigned_int_div_rem_2_2_memory<Torus, KSTorus> *mem_ptr) {
if (streams.count() < 4) {
PANIC("GPU count should be greater than 4 when using div_rem_2_2");
@@ -476,8 +476,8 @@ __host__ void host_unsigned_integer_div_rem(
CudaStreams streams, CudaRadixCiphertextFFI *quotient,
CudaRadixCiphertextFFI *remainder, CudaRadixCiphertextFFI const *numerator,
CudaRadixCiphertextFFI const *divisor, void *const *bsks,
uint64_t *const *ksks,
unsigned_int_div_rem_memory<uint64_t, uint64_t> *mem_ptr) {
KSTorus *const *ksks,
unsigned_int_div_rem_memory<Torus, KSTorus> *mem_ptr) {
if (remainder->num_radix_blocks != numerator->num_radix_blocks ||
remainder->num_radix_blocks != divisor->num_radix_blocks ||
@@ -740,7 +740,7 @@ __host__ void host_unsigned_integer_div_rem(
mem_ptr->overflow_sub_mem->update_lut_indexes(
streams, first_indexes, second_indexes, scalar_indexes,
merged_interesting_remainder->num_radix_blocks);
host_integer_overflowing_sub<uint64_t>(
host_integer_overflowing_sub<Torus, KSTorus>(
streams, new_remainder, merged_interesting_remainder,
interesting_divisor, subtraction_overflowed,
(const CudaRadixCiphertextFFI *)nullptr, mem_ptr->overflow_sub_mem,
@@ -903,13 +903,11 @@ __host__ void host_unsigned_integer_div_rem(
}
template <typename Torus, typename KSTorus>
__host__ void
host_integer_div_rem(CudaStreams streams, CudaRadixCiphertextFFI *quotient,
CudaRadixCiphertextFFI *remainder,
CudaRadixCiphertextFFI const *numerator,
CudaRadixCiphertextFFI const *divisor, bool is_signed,
void *const *bsks, uint64_t *const *ksks,
int_div_rem_memory<uint64_t, uint64_t> *int_mem_ptr) {
__host__ void host_integer_div_rem(
CudaStreams streams, CudaRadixCiphertextFFI *quotient,
CudaRadixCiphertextFFI *remainder, CudaRadixCiphertextFFI const *numerator,
CudaRadixCiphertextFFI const *divisor, bool is_signed, void *const *bsks,
KSTorus *const *ksks, int_div_rem_memory<Torus, KSTorus> *int_mem_ptr) {
if (remainder->num_radix_blocks != numerator->num_radix_blocks ||
remainder->num_radix_blocks != divisor->num_radix_blocks ||
remainder->num_radix_blocks != quotient->num_radix_blocks)

View File

@@ -116,6 +116,25 @@ uint64_t scratch_cuda_integer_overflowing_sub_64_inplace(
params, compute_overflow, allocate_gpu_memory);
}
// ks32 flavour of scratch_cuda_integer_overflowing_sub_64_inplace: allocates
// the borrow-propagation buffer for a 64-bit torus with 32-bit keyswitch
// keys. Returns the size reported by the scratch allocator.
uint64_t scratch_cuda_integer_overflowing_sub_64_ks32_inplace(
    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t big_lwe_dimension,
    uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
    uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
    uint32_t num_blocks, uint32_t message_modulus, uint32_t carry_modulus,
    PBS_TYPE pbs_type, uint32_t compute_overflow, bool allocate_gpu_memory,
    PBS_MS_REDUCTION_T noise_reduction_type) {
  // Fold the FFI scalars into the internal radix parameter struct.
  const int_radix_params params(
      pbs_type, glwe_dimension, polynomial_size, big_lwe_dimension,
      small_lwe_dimension, ks_level, ks_base_log, pbs_level, pbs_base_log,
      grouping_factor, message_modulus, carry_modulus, noise_reduction_type);
  auto **buffer = (int_borrow_prop_memory<uint64_t, uint32_t> **)mem_ptr;
  return scratch_cuda_integer_overflowing_sub<uint64_t, uint32_t>(
      CudaStreams(streams), buffer, num_blocks, params, compute_overflow,
      allocate_gpu_memory);
}
void cuda_propagate_single_carry_64_inplace(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array,
CudaRadixCiphertextFFI *carry_out, const CudaRadixCiphertextFFI *carry_in,
@@ -128,6 +147,18 @@ void cuda_propagate_single_carry_64_inplace(
(uint64_t **)(ksks), requested_flag, uses_carry);
}
// Propagates a single carry through `lwe_array` in place, optionally reading
// `carry_in` and writing `carry_out` per `requested_flag` / `uses_carry`
// (ks32: 32-bit keyswitch keys).
void cuda_propagate_single_carry_64_ks32_inplace(
    CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array,
    CudaRadixCiphertextFFI *carry_out, const CudaRadixCiphertextFFI *carry_in,
    int8_t *mem_ptr, void *const *bsks, void *const *ksks,
    uint32_t requested_flag, uint32_t uses_carry) {
  // `mem_ptr` is the single-carry-propagation scratch buffer.
  auto *scp_mem = (int_sc_prop_memory<uint64_t, uint32_t> *)mem_ptr;
  host_propagate_single_carry<uint64_t, uint32_t>(
      CudaStreams(streams), lwe_array, carry_out, carry_in, scp_mem, bsks,
      (uint32_t **)(ksks), requested_flag, uses_carry);
}
void cuda_add_and_propagate_single_carry_64_inplace(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lhs_array,
const CudaRadixCiphertextFFI *rhs_array, CudaRadixCiphertextFFI *carry_out,
@@ -146,10 +177,10 @@ void cuda_add_and_propagate_single_carry_64_ks32_inplace(
const CudaRadixCiphertextFFI *carry_in, int8_t *mem_ptr, void *const *bsks,
void *const *ksks, uint32_t requested_flag, uint32_t uses_carry) {
host_add_and_propagate_single_carry<uint64_t, uint32_t>(
CudaStreams(streams), lhs_array, rhs_array, carry_out, carry_in,
(int_sc_prop_memory<uint64_t, uint32_t> *)mem_ptr, bsks,
(uint32_t **)(ksks), requested_flag, uses_carry);
host_add_and_propagate_single_carry<uint64_t, uint32_t>(
CudaStreams(streams), lhs_array, rhs_array, carry_out, carry_in,
(int_sc_prop_memory<uint64_t, uint32_t> *)mem_ptr, bsks,
(uint32_t **)(ksks), requested_flag, uses_carry);
}
void cuda_integer_overflowing_sub_64_inplace(
@@ -167,6 +198,21 @@ void cuda_integer_overflowing_sub_64_inplace(
POP_RANGE()
}
// ks32 in-place overflowing subtraction entry point: subtracts `rhs_array`
// from `lhs_array`, storing the result back into `lhs_array`, with optional
// overflow reporting into `overflow_block` and an optional `input_borrow`.
void cuda_integer_overflowing_sub_64_ks32_inplace(
    CudaStreamsFFI streams, CudaRadixCiphertextFFI *lhs_array,
    const CudaRadixCiphertextFFI *rhs_array,
    CudaRadixCiphertextFFI *overflow_block,
    const CudaRadixCiphertextFFI *input_borrow, int8_t *mem_ptr,
    void *const *bsks, void *const *ksks, uint32_t compute_overflow,
    uint32_t uses_input_borrow) {
  PUSH_RANGE("overflow sub")
  // `lhs_array` is deliberately passed twice: once as the output and once as
  // the left operand, which makes the subtraction in place.
  host_integer_overflowing_sub<uint64_t, uint32_t>(
      CudaStreams(streams), lhs_array, lhs_array, rhs_array, overflow_block,
      input_borrow, (int_borrow_prop_memory<uint64_t, uint32_t> *)mem_ptr, bsks,
      (uint32_t **)ksks, compute_overflow, uses_input_borrow);
  POP_RANGE()
}
void cleanup_cuda_propagate_single_carry(CudaStreamsFFI streams,
int8_t **mem_ptr_void) {
PUSH_RANGE("cleanup propagate sc")

View File

@@ -115,7 +115,7 @@ void host_integer_grouped_oprf_custom_range(
const Torus *decomposed_scalar, const Torus *has_at_least_one_set,
uint32_t num_scalars, uint32_t shift,
int_grouped_oprf_custom_range_memory<Torus, KSTorus> *mem_ptr,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
CudaRadixCiphertextFFI *computation_buffer = mem_ptr->tmp_oprf_output;
set_zero_radix_ciphertext_slice_async<Torus>(

View File

@@ -20,6 +20,26 @@ uint64_t scratch_cuda_sub_and_propagate_single_carry_64_inplace(
requested_flag, allocate_gpu_memory);
}
uint64_t scratch_cuda_sub_and_propagate_single_carry_64_ks32_inplace(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t num_blocks, uint32_t message_modulus, uint32_t carry_modulus,
PBS_TYPE pbs_type, uint32_t requested_flag, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type) {
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
big_lwe_dimension, small_lwe_dimension, ks_level,
ks_base_log, pbs_level, pbs_base_log, grouping_factor,
message_modulus, carry_modulus, noise_reduction_type);
return scratch_cuda_sub_and_propagate_single_carry<uint64_t, uint32_t>(
CudaStreams(streams),
(int_sub_and_propagate<uint64_t, uint32_t> **)mem_ptr, num_blocks, params,
requested_flag, allocate_gpu_memory);
}
void cuda_sub_and_propagate_single_carry_64_inplace(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lhs_array,
const CudaRadixCiphertextFFI *rhs_array, CudaRadixCiphertextFFI *carry_out,
@@ -33,6 +53,19 @@ void cuda_sub_and_propagate_single_carry_64_inplace(
POP_RANGE()
}
void cuda_sub_and_propagate_single_carry_64_ks32_inplace(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lhs_array,
const CudaRadixCiphertextFFI *rhs_array, CudaRadixCiphertextFFI *carry_out,
const CudaRadixCiphertextFFI *carry_in, int8_t *mem_ptr, void *const *bsks,
void *const *ksks, uint32_t requested_flag, uint32_t uses_carry) {
PUSH_RANGE("sub")
host_sub_and_propagate_single_carry<uint64_t, uint32_t>(
CudaStreams(streams), lhs_array, rhs_array, carry_out, carry_in,
(int_sub_and_propagate<uint64_t, uint32_t> *)mem_ptr, bsks,
(uint32_t **)(ksks), requested_flag, uses_carry);
POP_RANGE()
}
void cleanup_cuda_sub_and_propagate_single_carry(CudaStreamsFFI streams,
int8_t **mem_ptr_void) {
PUSH_RANGE("cleanup sub")

View File

@@ -95,7 +95,8 @@ __host__ void host_integer_overflowing_sub(
CudaRadixCiphertextFFI *overflow_block,
const CudaRadixCiphertextFFI *input_borrow,
int_borrow_prop_memory<Torus, KSTorus> *mem_ptr, void *const *bsks,
Torus *const *ksks, uint32_t compute_overflow, uint32_t uses_input_borrow) {
KSTorus *const *ksks, uint32_t compute_overflow,
uint32_t uses_input_borrow) {
PUSH_RANGE("overflowing sub")
if (output->num_radix_blocks != input_left->num_radix_blocks ||
output->num_radix_blocks != input_right->num_radix_blocks)
@@ -124,8 +125,8 @@ __host__ void host_integer_overflowing_sub(
host_single_borrow_propagate<Torus>(
streams, output, overflow_block, input_borrow,
(int_borrow_prop_memory<Torus, KSTorus> *)mem_ptr, bsks, (Torus **)(ksks),
num_groups, compute_overflow, uses_input_borrow);
(int_borrow_prop_memory<Torus, KSTorus> *)mem_ptr, bsks,
(KSTorus **)(ksks), num_groups, compute_overflow, uses_input_borrow);
POP_RANGE()
}

View File

@@ -25,7 +25,7 @@ __host__ void host_unchecked_all_eq_slices(
CudaRadixCiphertextFFI const *lhs, CudaRadixCiphertextFFI const *rhs,
uint32_t num_inputs, uint32_t num_blocks,
int_unchecked_all_eq_slices_buffer<Torus, KSTorus> *mem_ptr,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
// sync_from(streams)
//
@@ -104,7 +104,7 @@ __host__ void host_unchecked_contains_sub_slice(
CudaRadixCiphertextFFI const *lhs, CudaRadixCiphertextFFI const *rhs,
uint32_t num_rhs, uint32_t num_blocks,
int_unchecked_contains_sub_slice_buffer<Torus, KSTorus> *mem_ptr,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
uint32_t num_windows = mem_ptr->num_windows;

View File

@@ -14,7 +14,7 @@ __host__ void host_compute_equality_selectors(
CudaRadixCiphertextFFI const *lwe_array_in, uint32_t num_blocks,
const uint64_t *h_decomposed_cleartexts,
int_equality_selectors_buffer<Torus, KSTorus> *mem_ptr, void *const *bsks,
Torus *const *ksks) {
KSTorus *const *ksks) {
uint32_t num_possible_values = mem_ptr->num_possible_values;
uint32_t message_modulus = mem_ptr->params.message_modulus;
@@ -91,7 +91,7 @@ __host__ void host_create_possible_results(
CudaRadixCiphertextFFI const *lwe_array_in_list,
uint32_t num_possible_values, const uint64_t *h_decomposed_cleartexts,
uint32_t num_blocks, int_possible_results_buffer<Torus, KSTorus> *mem_ptr,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
uint32_t max_packed_value = mem_ptr->max_packed_value;
uint32_t max_luts_per_call = mem_ptr->max_luts_per_call;
@@ -174,7 +174,7 @@ __host__ void host_aggregate_one_hot_vector(
CudaRadixCiphertextFFI const *lwe_array_in_list,
uint32_t num_input_ciphertexts, uint32_t num_blocks,
int_aggregate_one_hot_buffer<Torus, KSTorus> *mem_ptr, void *const *bsks,
Torus *const *ksks) {
KSTorus *const *ksks) {
int_radix_params params = mem_ptr->params;
uint32_t chunk_size = mem_ptr->chunk_size;
@@ -345,7 +345,7 @@ __host__ void host_unchecked_match_value(
CudaRadixCiphertextFFI const *lwe_array_in_ct,
const uint64_t *h_match_inputs, const uint64_t *h_match_outputs,
int_unchecked_match_buffer<Torus, KSTorus> *mem_ptr, void *const *bsks,
Torus *const *ksks) {
KSTorus *const *ksks) {
host_compute_equality_selectors<Torus>(
streams, mem_ptr->selectors_list, lwe_array_in_ct,
mem_ptr->num_input_blocks, h_match_inputs, mem_ptr->eq_selectors_buffer,
@@ -422,7 +422,7 @@ __host__ void host_unchecked_match_value_or(
const uint64_t *h_match_inputs, const uint64_t *h_match_outputs,
const uint64_t *h_or_value,
int_unchecked_match_value_or_buffer<Torus, KSTorus> *mem_ptr,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
host_unchecked_match_value<Torus>(streams, mem_ptr->tmp_match_result,
mem_ptr->tmp_match_bool, lwe_array_in_ct,
@@ -465,7 +465,7 @@ host_unchecked_contains(CudaStreams streams, CudaRadixCiphertextFFI *output,
CudaRadixCiphertextFFI const *value,
uint32_t num_inputs, uint32_t num_blocks,
int_unchecked_contains_buffer<Torus, KSTorus> *mem_ptr,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
mem_ptr->internal_cuda_streams.internal_streams_wait_for_main_stream_0(
streams);
@@ -516,7 +516,7 @@ __host__ void host_unchecked_contains_clear(
CudaRadixCiphertextFFI const *inputs, const uint64_t *h_clear_val,
uint32_t num_inputs, uint32_t num_blocks,
int_unchecked_contains_clear_buffer<Torus, KSTorus> *mem_ptr,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
cuda_memcpy_async_to_gpu(mem_ptr->d_clear_val, h_clear_val,
num_blocks * sizeof(Torus), streams.stream(0),
@@ -577,7 +577,7 @@ __host__ void host_unchecked_is_in_clears(
CudaRadixCiphertextFFI const *input, const uint64_t *h_cleartexts,
uint32_t num_clears, uint32_t num_blocks,
int_unchecked_is_in_clears_buffer<Torus, KSTorus> *mem_ptr,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
host_compute_equality_selectors<Torus>(streams, mem_ptr->unpacked_selectors,
input, num_blocks, h_cleartexts,
@@ -594,7 +594,7 @@ __host__ void host_compute_final_index_from_selectors(
CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *selectors,
uint32_t num_inputs, uint32_t num_blocks_index,
int_final_index_from_selectors_buffer<Torus, KSTorus> *mem_ptr,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
for (uint32_t i = 0; i < num_inputs; i++) {
CudaRadixCiphertextFFI const *src_selector = &selectors[i];
@@ -657,7 +657,7 @@ __host__ void host_unchecked_index_in_clears(
const uint64_t *h_cleartexts, uint32_t num_clears, uint32_t num_blocks,
uint32_t num_blocks_index,
int_unchecked_index_in_clears_buffer<Torus, KSTorus> *mem_ptr,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
host_compute_equality_selectors<Torus>(
streams, mem_ptr->final_index_buf->unpacked_selectors, input, num_blocks,
@@ -704,7 +704,7 @@ __host__ void host_unchecked_first_index_in_clears(
const uint64_t *h_unique_values, const uint64_t *h_unique_indices,
uint32_t num_unique, uint32_t num_blocks, uint32_t num_blocks_index,
int_unchecked_first_index_in_clears_buffer<Torus, KSTorus> *mem_ptr,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
host_compute_equality_selectors<Torus>(streams, mem_ptr->unpacked_selectors,
input, num_blocks, h_unique_values,
@@ -748,7 +748,7 @@ __host__ void host_unchecked_first_index_of_clear(
const uint64_t *h_clear_val, uint32_t num_inputs, uint32_t num_blocks,
uint32_t num_blocks_index,
int_unchecked_first_index_of_clear_buffer<Torus, KSTorus> *mem_ptr,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
cuda_memcpy_async_to_gpu(mem_ptr->d_clear_val, h_clear_val,
num_blocks * sizeof(Torus), streams.stream(0),
@@ -841,7 +841,7 @@ __host__ void host_unchecked_first_index_of(
CudaRadixCiphertextFFI const *value, uint32_t num_inputs,
uint32_t num_blocks, uint32_t num_blocks_index,
int_unchecked_first_index_of_buffer<Torus, KSTorus> *mem_ptr,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
mem_ptr->internal_cuda_streams.internal_streams_wait_for_main_stream_0(
streams);
@@ -924,7 +924,7 @@ __host__ void host_unchecked_index_of(
CudaRadixCiphertextFFI const *value, uint32_t num_inputs,
uint32_t num_blocks, uint32_t num_blocks_index,
int_unchecked_index_of_buffer<Torus, KSTorus> *mem_ptr, void *const *bsks,
Torus *const *ksks) {
KSTorus *const *ksks) {
mem_ptr->internal_cuda_streams.internal_streams_wait_for_main_stream_0(
streams);
@@ -992,7 +992,7 @@ __host__ void host_unchecked_index_of_clear(
uint32_t num_inputs, uint32_t num_blocks, uint32_t num_scalar_blocks,
uint32_t num_blocks_index,
int_unchecked_index_of_clear_buffer<Torus, KSTorus> *mem_ptr,
void *const *bsks, Torus *const *ksks) {
void *const *bsks, KSTorus *const *ksks) {
CudaRadixCiphertextFFI *packed_selectors =
mem_ptr->final_index_buf->packed_selectors;

View File

@@ -651,6 +651,29 @@ unsafe extern "C" {
noise_reduction_type: PBS_MS_REDUCTION_T,
) -> u64;
}
unsafe extern "C" {
pub fn scratch_cuda_comparison_64_ks32(
streams: CudaStreamsFFI,
mem_ptr: *mut *mut i8,
glwe_dimension: u32,
polynomial_size: u32,
big_lwe_dimension: u32,
small_lwe_dimension: u32,
ks_level: u32,
ks_base_log: u32,
pbs_level: u32,
pbs_base_log: u32,
grouping_factor: u32,
lwe_ciphertext_count: u32,
message_modulus: u32,
carry_modulus: u32,
pbs_type: PBS_TYPE,
op_type: COMPARISON_TYPE,
is_signed: bool,
allocate_gpu_memory: bool,
noise_reduction_type: PBS_MS_REDUCTION_T,
) -> u64;
}
unsafe extern "C" {
pub fn cuda_comparison_ciphertext_64(
streams: CudaStreamsFFI,
@@ -662,6 +685,17 @@ unsafe extern "C" {
ksks: *const *mut ffi::c_void,
);
}
unsafe extern "C" {
pub fn cuda_comparison_ciphertext_64_ks32(
streams: CudaStreamsFFI,
lwe_array_out: *mut CudaRadixCiphertextFFI,
lwe_array_1: *const CudaRadixCiphertextFFI,
lwe_array_2: *const CudaRadixCiphertextFFI,
mem_ptr: *mut i8,
bsks: *const *mut ffi::c_void,
ksks: *const *mut ffi::c_void,
);
}
unsafe extern "C" {
pub fn cuda_scalar_comparison_ciphertext_64(
streams: CudaStreamsFFI,
@@ -746,6 +780,37 @@ unsafe extern "C" {
ksks: *const *mut ffi::c_void,
);
}
unsafe extern "C" {
pub fn scratch_cuda_boolean_bitnot_64_ks32(
streams: CudaStreamsFFI,
mem_ptr: *mut *mut i8,
glwe_dimension: u32,
polynomial_size: u32,
big_lwe_dimension: u32,
small_lwe_dimension: u32,
ks_level: u32,
ks_base_log: u32,
pbs_level: u32,
pbs_base_log: u32,
grouping_factor: u32,
message_modulus: u32,
carry_modulus: u32,
pbs_type: PBS_TYPE,
lwe_ciphertext_count: u32,
is_unchecked: bool,
allocate_gpu_memory: bool,
noise_reduction_type: PBS_MS_REDUCTION_T,
) -> u64;
}
unsafe extern "C" {
pub fn cuda_boolean_bitnot_ciphertext_64_ks32(
streams: CudaStreamsFFI,
lwe_array: *mut CudaRadixCiphertextFFI,
mem_ptr: *mut i8,
bsks: *const *mut ffi::c_void,
ksks: *const *mut ffi::c_void,
);
}
unsafe extern "C" {
pub fn cleanup_cuda_boolean_bitnot(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
}
@@ -780,6 +845,28 @@ unsafe extern "C" {
noise_reduction_type: PBS_MS_REDUCTION_T,
) -> u64;
}
unsafe extern "C" {
pub fn scratch_cuda_bitop_64_ks32(
streams: CudaStreamsFFI,
mem_ptr: *mut *mut i8,
glwe_dimension: u32,
polynomial_size: u32,
big_lwe_dimension: u32,
small_lwe_dimension: u32,
ks_level: u32,
ks_base_log: u32,
pbs_level: u32,
pbs_base_log: u32,
grouping_factor: u32,
lwe_ciphertext_count: u32,
message_modulus: u32,
carry_modulus: u32,
pbs_type: PBS_TYPE,
op_type: BITOP_TYPE,
allocate_gpu_memory: bool,
noise_reduction_type: PBS_MS_REDUCTION_T,
) -> u64;
}
unsafe extern "C" {
pub fn cuda_scalar_bitop_ciphertext_64(
streams: CudaStreamsFFI,
@@ -804,6 +891,17 @@ unsafe extern "C" {
ksks: *const *mut ffi::c_void,
);
}
unsafe extern "C" {
pub fn cuda_bitop_ciphertext_64_ks32(
streams: CudaStreamsFFI,
lwe_array_out: *mut CudaRadixCiphertextFFI,
lwe_array_1: *const CudaRadixCiphertextFFI,
lwe_array_2: *const CudaRadixCiphertextFFI,
mem_ptr: *mut i8,
bsks: *const *mut ffi::c_void,
ksks: *const *mut ffi::c_void,
);
}
unsafe extern "C" {
pub fn cleanup_cuda_integer_bitop(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
}
@@ -828,6 +926,27 @@ unsafe extern "C" {
noise_reduction_type: PBS_MS_REDUCTION_T,
) -> u64;
}
unsafe extern "C" {
pub fn scratch_cuda_cmux_64_ks32(
streams: CudaStreamsFFI,
mem_ptr: *mut *mut i8,
glwe_dimension: u32,
polynomial_size: u32,
big_lwe_dimension: u32,
small_lwe_dimension: u32,
ks_level: u32,
ks_base_log: u32,
pbs_level: u32,
pbs_base_log: u32,
grouping_factor: u32,
lwe_ciphertext_count: u32,
message_modulus: u32,
carry_modulus: u32,
pbs_type: PBS_TYPE,
allocate_gpu_memory: bool,
noise_reduction_type: PBS_MS_REDUCTION_T,
) -> u64;
}
unsafe extern "C" {
pub fn cuda_cmux_ciphertext_64(
streams: CudaStreamsFFI,
@@ -840,6 +959,18 @@ unsafe extern "C" {
ksks: *const *mut ffi::c_void,
);
}
unsafe extern "C" {
pub fn cuda_cmux_ciphertext_64_ks32(
streams: CudaStreamsFFI,
lwe_array_out: *mut CudaRadixCiphertextFFI,
lwe_condition: *const CudaRadixCiphertextFFI,
lwe_array_true: *const CudaRadixCiphertextFFI,
lwe_array_false: *const CudaRadixCiphertextFFI,
mem_ptr: *mut i8,
bsks: *const *mut ffi::c_void,
ksks: *const *mut ffi::c_void,
);
}
unsafe extern "C" {
pub fn cleanup_cuda_cmux(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
}
@@ -1016,6 +1147,28 @@ unsafe extern "C" {
noise_reduction_type: PBS_MS_REDUCTION_T,
) -> u64;
}
unsafe extern "C" {
pub fn scratch_cuda_integer_overflowing_sub_64_ks32_inplace(
streams: CudaStreamsFFI,
mem_ptr: *mut *mut i8,
glwe_dimension: u32,
polynomial_size: u32,
big_lwe_dimension: u32,
small_lwe_dimension: u32,
ks_level: u32,
ks_base_log: u32,
pbs_level: u32,
pbs_base_log: u32,
grouping_factor: u32,
num_blocks: u32,
message_modulus: u32,
carry_modulus: u32,
pbs_type: PBS_TYPE,
compute_overflow: u32,
allocate_gpu_memory: bool,
noise_reduction_type: PBS_MS_REDUCTION_T,
) -> u64;
}
unsafe extern "C" {
pub fn cuda_integer_overflowing_sub_64_inplace(
streams: CudaStreamsFFI,
@@ -1030,6 +1183,20 @@ unsafe extern "C" {
uses_input_borrow: u32,
);
}
unsafe extern "C" {
pub fn cuda_integer_overflowing_sub_64_ks32_inplace(
streams: CudaStreamsFFI,
lhs_array: *mut CudaRadixCiphertextFFI,
rhs_array: *const CudaRadixCiphertextFFI,
overflow_block: *mut CudaRadixCiphertextFFI,
input_borrow: *const CudaRadixCiphertextFFI,
mem_ptr: *mut i8,
bsks: *const *mut ffi::c_void,
ksks: *const *mut ffi::c_void,
compute_overflow: u32,
uses_input_borrow: u32,
);
}
unsafe extern "C" {
pub fn cleanup_cuda_integer_overflowing_sub(
streams: CudaStreamsFFI,
@@ -1169,6 +1336,28 @@ unsafe extern "C" {
noise_reduction_type: PBS_MS_REDUCTION_T,
) -> u64;
}
unsafe extern "C" {
pub fn scratch_cuda_integer_div_rem_radix_ciphertext_64_ks32(
streams: CudaStreamsFFI,
is_signed: bool,
mem_ptr: *mut *mut i8,
glwe_dimension: u32,
polynomial_size: u32,
big_lwe_dimension: u32,
small_lwe_dimension: u32,
ks_level: u32,
ks_base_log: u32,
pbs_level: u32,
pbs_base_log: u32,
grouping_factor: u32,
num_blocks: u32,
message_modulus: u32,
carry_modulus: u32,
pbs_type: PBS_TYPE,
allocate_gpu_memory: bool,
noise_reduction_type: PBS_MS_REDUCTION_T,
) -> u64;
}
unsafe extern "C" {
pub fn cuda_integer_div_rem_radix_ciphertext_64(
streams: CudaStreamsFFI,
@@ -1182,6 +1371,19 @@ unsafe extern "C" {
ksks: *const *mut ffi::c_void,
);
}
unsafe extern "C" {
pub fn cuda_integer_div_rem_radix_ciphertext_64_ks32(
streams: CudaStreamsFFI,
quotient: *mut CudaRadixCiphertextFFI,
remainder: *mut CudaRadixCiphertextFFI,
numerator: *const CudaRadixCiphertextFFI,
divisor: *const CudaRadixCiphertextFFI,
is_signed: bool,
mem_ptr: *mut i8,
bsks: *const *mut ffi::c_void,
ksks: *const *mut ffi::c_void,
);
}
unsafe extern "C" {
pub fn cleanup_cuda_integer_div_rem(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
}
@@ -1826,6 +2028,30 @@ unsafe extern "C" {
noise_reduction_type: PBS_MS_REDUCTION_T,
) -> u64;
}
unsafe extern "C" {
pub fn scratch_cuda_cast_to_unsigned_64_ks32(
streams: CudaStreamsFFI,
mem_ptr: *mut *mut i8,
glwe_dimension: u32,
polynomial_size: u32,
big_lwe_dimension: u32,
small_lwe_dimension: u32,
ks_level: u32,
ks_base_log: u32,
pbs_level: u32,
pbs_base_log: u32,
grouping_factor: u32,
num_input_blocks: u32,
target_num_blocks: u32,
input_is_signed: bool,
requires_full_propagate: bool,
message_modulus: u32,
carry_modulus: u32,
pbs_type: PBS_TYPE,
allocate_gpu_memory: bool,
noise_reduction_type: PBS_MS_REDUCTION_T,
) -> u64;
}
unsafe extern "C" {
pub fn cuda_cast_to_unsigned_64(
streams: CudaStreamsFFI,
@@ -1838,6 +2064,18 @@ unsafe extern "C" {
ksks: *const *mut ffi::c_void,
);
}
unsafe extern "C" {
pub fn cuda_cast_to_unsigned_64_ks32(
streams: CudaStreamsFFI,
output: *mut CudaRadixCiphertextFFI,
input: *mut CudaRadixCiphertextFFI,
mem_ptr: *mut i8,
target_num_blocks: u32,
input_is_signed: bool,
bsks: *const *mut ffi::c_void,
ksks: *const *mut ffi::c_void,
);
}
unsafe extern "C" {
pub fn cleanup_cuda_cast_to_unsigned_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
}

View File

@@ -866,6 +866,9 @@ fn main() {
ParamType::Classical => {
benchmark::params_aliases::BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into()
},
ParamType::ClassicalKs32 => {
benchmark::params_aliases::BENCH_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M64.into()
},
_ => {
benchmark::params_aliases::BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into()
}

View File

@@ -2719,7 +2719,13 @@ mod cuda {
rng_func: default_scalar
);
criterion_group!(cuda_ops_support_ks32, cuda_mul, cuda_add, cuda_sub,);
criterion_group!(
cuda_ops_support_ks32,
cuda_mul,
cuda_add,
cuda_sub,
cuda_div_rem,
);
criterion_group!(
unchecked_cuda_ops,

View File

@@ -1492,35 +1492,70 @@ pub(crate) unsafe fn cuda_backend_unchecked_bitop_assign<T: UnsignedInteger, B:
&mut radix_lwe_right_degrees,
&mut radix_lwe_right_noise_levels,
);
scratch_cuda_bitop_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_blocks,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
op as u32,
true,
noise_reduction_type as u32,
);
cuda_bitop_ciphertext_64(
streams.ffi(),
&raw mut cuda_ffi_radix_lwe_left,
&raw const cuda_ffi_radix_lwe_left,
&raw const cuda_ffi_radix_lwe_right,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
if TypeId::of::<T>() == TypeId::of::<u64>() {
scratch_cuda_bitop_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_blocks,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
op as u32,
true,
noise_reduction_type as u32,
);
cuda_bitop_ciphertext_64(
streams.ffi(),
&raw mut cuda_ffi_radix_lwe_left,
&raw const cuda_ffi_radix_lwe_left,
&raw const cuda_ffi_radix_lwe_right,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
} else if TypeId::of::<T>() == TypeId::of::<u32>() {
scratch_cuda_bitop_64_ks32(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_blocks,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
op as u32,
true,
noise_reduction_type as u32,
);
cuda_bitop_ciphertext_64_ks32(
streams.ffi(),
&raw mut cuda_ffi_radix_lwe_left,
&raw const cuda_ffi_radix_lwe_left,
&raw const cuda_ffi_radix_lwe_right,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
} else {
panic!("Unsupported KS dtype");
}
cleanup_cuda_integer_bitop(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
update_noise_degree(radix_lwe_left, &cuda_ffi_radix_lwe_left);
}
@@ -2093,37 +2128,73 @@ pub(crate) unsafe fn cuda_backend_unchecked_comparison<T: UnsignedInteger, B: Nu
&mut radix_lwe_right_noise_levels,
);
scratch_cuda_comparison_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
radix_lwe_left.d_blocks.lwe_ciphertext_count().0 as u32,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
op as u32,
is_signed,
true,
noise_reduction_type as u32,
);
if TypeId::of::<T>() == TypeId::of::<u64>() {
scratch_cuda_comparison_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
radix_lwe_left.d_blocks.lwe_ciphertext_count().0 as u32,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
op as u32,
is_signed,
true,
noise_reduction_type as u32,
);
cuda_comparison_ciphertext_64(
streams.ffi(),
&raw mut cuda_ffi_radix_lwe_out,
&raw const cuda_ffi_radix_lwe_left,
&raw const cuda_ffi_radix_lwe_right,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
cuda_comparison_ciphertext_64(
streams.ffi(),
&raw mut cuda_ffi_radix_lwe_out,
&raw const cuda_ffi_radix_lwe_left,
&raw const cuda_ffi_radix_lwe_right,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
} else if TypeId::of::<T>() == TypeId::of::<u32>() {
scratch_cuda_comparison_64_ks32(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
radix_lwe_left.d_blocks.lwe_ciphertext_count().0 as u32,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
op as u32,
is_signed,
true,
noise_reduction_type as u32,
);
cuda_comparison_ciphertext_64_ks32(
streams.ffi(),
&raw mut cuda_ffi_radix_lwe_out,
&raw const cuda_ffi_radix_lwe_left,
&raw const cuda_ffi_radix_lwe_right,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
} else {
panic!("Unsupported KS datatype");
}
cleanup_cuda_integer_comparison(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
update_noise_degree(radix_lwe_out, &cuda_ffi_radix_lwe_out);
@@ -5565,35 +5636,71 @@ pub(crate) unsafe fn cuda_backend_unchecked_cmux<T: UnsignedInteger, B: Numeric>
&mut condition_noise_levels,
);
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
scratch_cuda_cmux_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_blocks,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
true,
noise_reduction_type as u32,
);
cuda_cmux_ciphertext_64(
streams.ffi(),
&raw mut cuda_ffi_radix_lwe_out,
&raw const cuda_ffi_condition,
&raw const cuda_ffi_radix_lwe_true,
&raw const cuda_ffi_radix_lwe_false,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
if TypeId::of::<T>() == TypeId::of::<u64>() {
scratch_cuda_cmux_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_blocks,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
true,
noise_reduction_type as u32,
);
cuda_cmux_ciphertext_64(
streams.ffi(),
&raw mut cuda_ffi_radix_lwe_out,
&raw const cuda_ffi_condition,
&raw const cuda_ffi_radix_lwe_true,
&raw const cuda_ffi_radix_lwe_false,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
} else if TypeId::of::<T>() == TypeId::of::<u32>() {
scratch_cuda_cmux_64_ks32(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_blocks,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
true,
noise_reduction_type as u32,
);
cuda_cmux_ciphertext_64_ks32(
streams.ffi(),
&raw mut cuda_ffi_radix_lwe_out,
&raw const cuda_ffi_condition,
&raw const cuda_ffi_radix_lwe_true,
&raw const cuda_ffi_radix_lwe_false,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
} else {
panic!("Unsupported KS dtype");
}
cleanup_cuda_cmux(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
update_noise_degree(radix_lwe_out, &cuda_ffi_radix_lwe_out);
}
@@ -6963,38 +7070,76 @@ pub(crate) unsafe fn cuda_backend_unchecked_unsigned_overflowing_sub_assign<
.collect();
let cuda_ffi_carry_in =
prepare_cuda_radix_ffi(carry_in, &mut carry_in_degrees, &mut carry_in_noise_levels);
scratch_cuda_integer_overflowing_sub_64_inplace(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension,
lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
radix_lwe_left.d_blocks.lwe_ciphertext_count().0 as u32,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
compute_overflow as u32,
true,
noise_reduction_type as u32,
);
cuda_integer_overflowing_sub_64_inplace(
streams.ffi(),
&raw mut cuda_ffi_radix_lwe_left,
&raw const cuda_ffi_radix_lwe_right,
&raw mut cuda_ffi_carry_out,
&raw const cuda_ffi_carry_in,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
compute_overflow as u32,
uses_input_borrow,
);
if TypeId::of::<T>() == TypeId::of::<u64>() {
scratch_cuda_integer_overflowing_sub_64_inplace(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension,
lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
radix_lwe_left.d_blocks.lwe_ciphertext_count().0 as u32,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
compute_overflow as u32,
true,
noise_reduction_type as u32,
);
cuda_integer_overflowing_sub_64_inplace(
streams.ffi(),
&raw mut cuda_ffi_radix_lwe_left,
&raw const cuda_ffi_radix_lwe_right,
&raw mut cuda_ffi_carry_out,
&raw const cuda_ffi_carry_in,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
compute_overflow as u32,
uses_input_borrow,
);
} else if TypeId::of::<T>() == TypeId::of::<u32>() {
scratch_cuda_integer_overflowing_sub_64_ks32_inplace(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension,
lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
radix_lwe_left.d_blocks.lwe_ciphertext_count().0 as u32,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
compute_overflow as u32,
true,
noise_reduction_type as u32,
);
cuda_integer_overflowing_sub_64_ks32_inplace(
streams.ffi(),
&raw mut cuda_ffi_radix_lwe_left,
&raw const cuda_ffi_radix_lwe_right,
&raw mut cuda_ffi_carry_out,
&raw const cuda_ffi_carry_in,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
compute_overflow as u32,
uses_input_borrow,
);
} else {
panic!("Unsupported KS dtype");
}
cleanup_cuda_integer_overflowing_sub(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
update_noise_degree(radix_lwe_left, &cuda_ffi_radix_lwe_left);
update_noise_degree(carry_out, &cuda_ffi_carry_out);
@@ -8291,34 +8436,68 @@ pub(crate) unsafe fn cuda_backend_boolean_bitnot_assign<T: UnsignedInteger, B: N
&mut ciphertext_noise_levels,
);
scratch_cuda_boolean_bitnot_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
1u32,
is_unchecked,
true,
noise_reduction_type as u32,
);
if TypeId::of::<T>() == TypeId::of::<u64>() {
scratch_cuda_boolean_bitnot_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
1u32,
is_unchecked,
true,
noise_reduction_type as u32,
);
cuda_boolean_bitnot_ciphertext_64(
streams.ffi(),
&raw mut cuda_ffi_ciphertext,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
} else if TypeId::of::<T>() == TypeId::of::<u32>() {
scratch_cuda_boolean_bitnot_64_ks32(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
1u32,
is_unchecked,
true,
noise_reduction_type as u32,
);
cuda_boolean_bitnot_ciphertext_64_ks32(
streams.ffi(),
&raw mut cuda_ffi_ciphertext,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
} else {
panic!("Unsupported KS dype");
}
cuda_boolean_bitnot_ciphertext_64(
streams.ffi(),
&raw mut cuda_ffi_ciphertext,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
cleanup_cuda_boolean_bitnot(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
update_noise_degree(ciphertext, &cuda_ffi_ciphertext);
}
@@ -8792,39 +8971,75 @@ pub(crate) unsafe fn cuda_backend_cast_to_unsigned<T: UnsignedInteger, B: Numeri
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
scratch_cuda_cast_to_unsigned_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_input_blocks,
target_num_blocks,
input_is_signed,
requires_full_propagate,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
true,
noise_reduction_type as u32,
);
if TypeId::of::<T>() == TypeId::of::<u32>() {
scratch_cuda_cast_to_unsigned_64_ks32(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_input_blocks,
target_num_blocks,
input_is_signed,
requires_full_propagate,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
true,
noise_reduction_type as u32,
);
cuda_cast_to_unsigned_64(
streams.ffi(),
&raw mut cuda_ffi_output,
&raw mut cuda_ffi_input,
mem_ptr,
target_num_blocks,
input_is_signed,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
cuda_cast_to_unsigned_64_ks32(
streams.ffi(),
&raw mut cuda_ffi_output,
&raw mut cuda_ffi_input,
mem_ptr,
target_num_blocks,
input_is_signed,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
} else if TypeId::of::<T>() == TypeId::of::<u64>() {
scratch_cuda_cast_to_unsigned_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_input_blocks,
target_num_blocks,
input_is_signed,
requires_full_propagate,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
true,
noise_reduction_type as u32,
);
cuda_cast_to_unsigned_64(
streams.ffi(),
&raw mut cuda_ffi_output,
&raw mut cuda_ffi_input,
mem_ptr,
target_num_blocks,
input_is_signed,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
}
cleanup_cuda_cast_to_unsigned_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));

View File

@@ -408,57 +408,103 @@ impl CudaServerKey {
is_unchecked: bool,
streams: &CudaStreams,
) {
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_boolean_bitnot_assign(
streams,
&mut ct.0.ciphertext as &mut CudaRadixCiphertext,
is_unchecked,
&d_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
d_bsk.output_lwe_dimension(),
d_bsk.input_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
match &self.key_switching_key {
CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_boolean_bitnot_assign(
streams,
&mut ct.0.ciphertext as &mut CudaRadixCiphertext,
is_unchecked,
&d_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
d_bsk.output_lwe_dimension(),
d_bsk.input_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_boolean_bitnot_assign(
streams,
&mut ct.0.ciphertext as &mut CudaRadixCiphertext,
is_unchecked,
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
d_multibit_bsk.output_lwe_dimension(),
d_multibit_bsk.input_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_boolean_bitnot_assign(
streams,
&mut ct.0.ciphertext as &mut CudaRadixCiphertext,
is_unchecked,
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
d_multibit_bsk.output_lwe_dimension(),
d_multibit_bsk.input_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
},
CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_boolean_bitnot_assign(
streams,
&mut ct.0.ciphertext as &mut CudaRadixCiphertext,
is_unchecked,
&d_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
d_bsk.output_lwe_dimension(),
d_bsk.input_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_boolean_bitnot_assign(
streams,
&mut ct.0.ciphertext as &mut CudaRadixCiphertext,
is_unchecked,
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
d_multibit_bsk.output_lwe_dimension(),
d_multibit_bsk.input_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
},
}
}
@@ -534,61 +580,111 @@ impl CudaServerKey {
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_bitop_assign(
streams,
ct_left.as_mut(),
ct_right.as_ref(),
&d_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
op,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
match &self.key_switching_key {
CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_bitop_assign(
streams,
ct_left.as_mut(),
ct_right.as_ref(),
&d_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
op,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_bitop_assign(
streams,
ct_left.as_mut(),
ct_right.as_ref(),
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
op,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_bitop_assign(
streams,
ct_left.as_mut(),
ct_right.as_ref(),
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
op,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
},
CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_bitop_assign(
streams,
ct_left.as_mut(),
ct_right.as_ref(),
&d_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
op,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_bitop_assign(
streams,
ct_left.as_mut(),
ct_right.as_ref(),
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
op,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
},
}
}

View File

@@ -20,64 +20,117 @@ impl CudaServerKey {
let mut result: T = self
.create_trivial_zero_radix(true_ct.as_ref().d_blocks.lwe_ciphertext_count().0, stream);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_cmux(
stream,
result.as_mut(),
condition,
true_ct.as_ref(),
false_ct.as_ref(),
&d_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
match &self.key_switching_key {
CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_cmux(
stream,
result.as_mut(),
condition,
true_ct.as_ref(),
false_ct.as_ref(),
&d_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_cmux(
stream,
result.as_mut(),
condition,
true_ct.as_ref(),
false_ct.as_ref(),
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_cmux(
stream,
result.as_mut(),
condition,
true_ct.as_ref(),
false_ct.as_ref(),
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
},
CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_cmux(
stream,
result.as_mut(),
condition,
true_ct.as_ref(),
false_ct.as_ref(),
&d_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_cmux(
stream,
result.as_mut(),
condition,
true_ct.as_ref(),
false_ct.as_ref(),
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
},
}
result
}

View File

@@ -45,65 +45,117 @@ impl CudaServerKey {
let mut result =
CudaBooleanBlock::from_cuda_radix_ciphertext(CudaRadixCiphertext::new(block, ct_info));
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
match &self.key_switching_key {
CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_comparison(
streams,
result.as_mut().as_mut(),
ct_left.as_ref(),
ct_right.as_ref(),
&d_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
op,
T::IS_SIGNED,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_comparison(
streams,
result.as_mut().as_mut(),
ct_left.as_ref(),
ct_right.as_ref(),
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
op,
T::IS_SIGNED,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
},
CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_comparison(
streams,
result.as_mut().as_mut(),
ct_left.as_ref(),
ct_right.as_ref(),
&d_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
op,
T::IS_SIGNED,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_comparison(
streams,
result.as_mut().as_mut(),
ct_left.as_ref(),
ct_right.as_ref(),
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
op,
T::IS_SIGNED,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
},
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_comparison(
streams,
result.as_mut().as_mut(),
ct_left.as_ref(),
ct_right.as_ref(),
&d_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
op,
T::IS_SIGNED,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_comparison(
streams,
result.as_mut().as_mut(),
ct_left.as_ref(),
ct_right.as_ref(),
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
op,
T::IS_SIGNED,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
result
}

View File

@@ -1019,61 +1019,109 @@ impl CudaServerKey {
self.create_trivial_zero_radix(target_num_blocks, streams);
let requires_full_propagate = !source.block_carries_are_empty();
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_cast_to_unsigned(
streams,
result.as_mut(),
source.as_mut(),
T::IS_SIGNED,
requires_full_propagate,
target_num_blocks as u32,
&d_bsk.d_vec,
&computing_ks_key.d_vec,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
match &self.key_switching_key {
CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_cast_to_unsigned(
streams,
result.as_mut(),
source.as_mut(),
T::IS_SIGNED,
requires_full_propagate,
target_num_blocks as u32,
&d_bsk.d_vec,
&computing_ks_key.d_vec,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_cast_to_unsigned(
streams,
result.as_mut(),
source.as_mut(),
T::IS_SIGNED,
requires_full_propagate,
target_num_blocks as u32,
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_cast_to_unsigned(
streams,
result.as_mut(),
source.as_mut(),
T::IS_SIGNED,
requires_full_propagate,
target_num_blocks as u32,
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
},
CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_cast_to_unsigned(
streams,
result.as_mut(),
source.as_mut(),
T::IS_SIGNED,
requires_full_propagate,
target_num_blocks as u32,
&d_bsk.d_vec,
&computing_ks_key.d_vec,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_cast_to_unsigned(
streams,
result.as_mut(),
source.as_mut(),
T::IS_SIGNED,
requires_full_propagate,
target_num_blocks as u32,
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
},
}
result
}

View File

@@ -296,64 +296,118 @@ impl CudaServerKey {
let aux_block: CudaUnsignedRadixCiphertext = self.create_trivial_zero_radix(1, stream);
let in_carry_dvec =
INPUT_BORROW.map_or_else(|| aux_block.as_ref(), |block| block.as_ref().as_ref());
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
stream,
ciphertext,
rhs.as_ref(),
overflow_block.as_mut(),
in_carry_dvec,
&d_bsk.d_vec,
&computing_ks_key.d_vec,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
ciphertext.info.blocks.first().unwrap().message_modulus,
ciphertext.info.blocks.first().unwrap().carry_modulus,
PBSType::Classical,
LweBskGroupingFactor(0),
compute_overflow,
uses_input_borrow,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
match &self.key_switching_key {
CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
stream,
ciphertext,
rhs.as_ref(),
overflow_block.as_mut(),
in_carry_dvec,
&d_bsk.d_vec,
&computing_ks_key.d_vec,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
ciphertext.info.blocks.first().unwrap().message_modulus,
ciphertext.info.blocks.first().unwrap().carry_modulus,
PBSType::Classical,
LweBskGroupingFactor(0),
compute_overflow,
uses_input_borrow,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
stream,
ciphertext,
rhs.as_ref(),
overflow_block.as_mut(),
in_carry_dvec,
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
ciphertext.info.blocks.first().unwrap().message_modulus,
ciphertext.info.blocks.first().unwrap().carry_modulus,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
compute_overflow,
uses_input_borrow,
None,
);
}
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
stream,
ciphertext,
rhs.as_ref(),
overflow_block.as_mut(),
in_carry_dvec,
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
ciphertext.info.blocks.first().unwrap().message_modulus,
ciphertext.info.blocks.first().unwrap().carry_modulus,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
compute_overflow,
uses_input_borrow,
None,
);
},
CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
stream,
ciphertext,
rhs.as_ref(),
overflow_block.as_mut(),
in_carry_dvec,
&d_bsk.d_vec,
&computing_ks_key.d_vec,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
ciphertext.info.blocks.first().unwrap().message_modulus,
ciphertext.info.blocks.first().unwrap().carry_modulus,
PBSType::Classical,
LweBskGroupingFactor(0),
compute_overflow,
uses_input_borrow,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
stream,
ciphertext,
rhs.as_ref(),
overflow_block.as_mut(),
in_carry_dvec,
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
ciphertext.info.blocks.first().unwrap().message_modulus,
ciphertext.info.blocks.first().unwrap().carry_modulus,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
compute_overflow,
uses_input_borrow,
None,
);
}
}
}
},
}
let ct_overflowed = CudaBooleanBlock::from_cuda_radix_ciphertext(overflow_block.ciphertext);
(ct_res, ct_overflowed)