mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
feat(gpu): support keyswitch 64/32
This commit is contained in:
@@ -32,7 +32,7 @@ void cuda_integer_aes_ctr_encrypt_64(CudaStreamsFFI streams,
|
||||
host_integer_aes_ctr_encrypt<uint64_t>(
|
||||
CudaStreams(streams), output, iv, round_keys, counter_bits_le_all_blocks,
|
||||
num_aes_inputs, (int_aes_encrypt_buffer<uint64_t> *)mem_ptr, bsks,
|
||||
(uint64_t **)ksks);
|
||||
(uint32_t **)ksks);
|
||||
}
|
||||
|
||||
void cleanup_cuda_integer_aes_encrypt_64(CudaStreamsFFI streams,
|
||||
@@ -74,7 +74,7 @@ void cuda_integer_key_expansion_64(CudaStreamsFFI streams,
|
||||
|
||||
host_integer_key_expansion<uint64_t>(
|
||||
CudaStreams(streams), expanded_keys, key,
|
||||
(int_key_expansion_buffer<uint64_t> *)mem_ptr, bsks, (uint64_t **)ksks);
|
||||
(int_key_expansion_buffer<uint64_t> *)mem_ptr, bsks, (uint32_t **)ksks);
|
||||
}
|
||||
|
||||
void cleanup_cuda_integer_key_expansion_64(CudaStreamsFFI streams,
|
||||
|
||||
@@ -105,11 +105,11 @@ aes_xor(CudaStreams streams, int_aes_encrypt_buffer<Torus> *mem,
|
||||
* result.
|
||||
*
|
||||
*/
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ __forceinline__ void
|
||||
aes_flush_inplace(CudaStreams streams, CudaRadixCiphertextFFI *data,
|
||||
int_aes_encrypt_buffer<Torus> *mem, void *const *bsks,
|
||||
Torus *const *ksks) {
|
||||
KSTorus *const *ksks) {
|
||||
|
||||
integer_radix_apply_univariate_lookup_table_kb<Torus>(
|
||||
streams, data, data, bsks, ksks, mem->luts->flush_lut,
|
||||
@@ -121,10 +121,10 @@ aes_flush_inplace(CudaStreams streams, CudaRadixCiphertextFFI *data,
|
||||
* ciphertext, then flushes the result to ensure it's a valid bit.
|
||||
*
|
||||
*/
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ __forceinline__ void aes_scalar_add_one_flush_inplace(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *data,
|
||||
int_aes_encrypt_buffer<Torus> *mem, void *const *bsks, Torus *const *ksks) {
|
||||
int_aes_encrypt_buffer<Torus> *mem, void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
host_integer_radix_add_scalar_one_inplace<Torus>(
|
||||
streams, data, mem->params.message_modulus, mem->params.carry_modulus);
|
||||
@@ -142,11 +142,11 @@ __host__ __forceinline__ void aes_scalar_add_one_flush_inplace(
|
||||
* ciphertext locations.
|
||||
*
|
||||
*/
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void
|
||||
batch_vec_flush_inplace(CudaStreams streams, CudaRadixCiphertextFFI **targets,
|
||||
size_t count, int_aes_encrypt_buffer<Torus> *mem,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
uint32_t num_radix_blocks = targets[0]->num_radix_blocks;
|
||||
|
||||
@@ -185,13 +185,13 @@ batch_vec_flush_inplace(CudaStreams streams, CudaRadixCiphertextFFI **targets,
|
||||
* Batches multiple "and" operations into a single, large launch.
|
||||
*
|
||||
*/
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void batch_vec_and_inplace(CudaStreams streams,
|
||||
CudaRadixCiphertextFFI **outs,
|
||||
CudaRadixCiphertextFFI **lhs,
|
||||
CudaRadixCiphertextFFI **rhs, size_t count,
|
||||
int_aes_encrypt_buffer<Torus> *mem,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
uint32_t num_aes_inputs = outs[0]->num_radix_blocks;
|
||||
|
||||
@@ -274,13 +274,13 @@ __host__ void batch_vec_and_inplace(CudaStreams streams,
|
||||
* [ptr] -> [R2b0, R2b1, R2b2, R2b3, R2b4, R2b5, R2b6, R2b7]
|
||||
* ...
|
||||
*/
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void vectorized_sbox_n_bytes(CudaStreams streams,
|
||||
CudaRadixCiphertextFFI **sbox_io_bytes,
|
||||
uint32_t num_bytes_parallel,
|
||||
uint32_t num_aes_inputs,
|
||||
int_aes_encrypt_buffer<Torus> *mem,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
uint32_t num_sbox_blocks = num_bytes_parallel * num_aes_inputs;
|
||||
|
||||
@@ -702,12 +702,12 @@ __host__ void vectorized_mul_by_2(CudaStreams streams,
|
||||
* [ s'_3 ] [ 03 01 01 02 ] [ s_3 ]
|
||||
*
|
||||
*/
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void vectorized_mix_columns(CudaStreams streams,
|
||||
CudaRadixCiphertextFFI *s_bits,
|
||||
uint32_t num_aes_inputs,
|
||||
int_aes_encrypt_buffer<Torus> *mem,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
constexpr uint32_t BITS_PER_BYTE = 8;
|
||||
constexpr uint32_t BYTES_PER_COLUMN = 4;
|
||||
@@ -842,11 +842,11 @@ __host__ void vectorized_mix_columns(CudaStreams streams,
|
||||
* - AddRoundKey
|
||||
*
|
||||
*/
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void vectorized_aes_encrypt_inplace(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *all_states_bitsliced,
|
||||
CudaRadixCiphertextFFI const *round_keys, uint32_t num_aes_inputs,
|
||||
int_aes_encrypt_buffer<Torus> *mem, void *const *bsks, Torus *const *ksks) {
|
||||
int_aes_encrypt_buffer<Torus> *mem, void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
constexpr uint32_t BITS_PER_BYTE = 8;
|
||||
constexpr uint32_t STATE_BYTES = 16;
|
||||
@@ -987,11 +987,11 @@ __host__ void vectorized_aes_encrypt_inplace(
|
||||
* The "transposed_states" buffer is updated in-place with the sum bits $S_i$.
|
||||
*
|
||||
*/
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void vectorized_aes_full_adder_inplace(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *transposed_states,
|
||||
const Torus *counter_bits_le_all_blocks, uint32_t num_aes_inputs,
|
||||
int_aes_encrypt_buffer<Torus> *mem, void *const *bsks, Torus *const *ksks) {
|
||||
int_aes_encrypt_buffer<Torus> *mem, void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
constexpr uint32_t NUM_BITS = 128;
|
||||
|
||||
@@ -1091,12 +1091,12 @@ __host__ void vectorized_aes_full_adder_inplace(
|
||||
* +---------------------------------+
|
||||
*
|
||||
*/
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_aes_ctr_encrypt(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *output,
|
||||
CudaRadixCiphertextFFI const *iv, CudaRadixCiphertextFFI const *round_keys,
|
||||
const Torus *counter_bits_le_all_blocks, uint32_t num_aes_inputs,
|
||||
int_aes_encrypt_buffer<Torus> *mem, void *const *bsks, Torus *const *ksks) {
|
||||
int_aes_encrypt_buffer<Torus> *mem, void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
constexpr uint32_t NUM_BITS = 128;
|
||||
|
||||
@@ -1148,13 +1148,13 @@ uint64_t scratch_cuda_integer_key_expansion(
|
||||
* - If (i % 4 == 0): w_i = w_{i-4} + SubWord(RotWord(w_{i-1})) + Rcon[i/4]
|
||||
* - If (i % 4 != 0): w_i = w_{i-4} + w_{i-1}
|
||||
*/
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_key_expansion(CudaStreams streams,
|
||||
CudaRadixCiphertextFFI *expanded_keys,
|
||||
CudaRadixCiphertextFFI const *key,
|
||||
int_key_expansion_buffer<Torus> *mem,
|
||||
void *const *bsks,
|
||||
Torus *const *ksks) {
|
||||
KSTorus *const *ksks) {
|
||||
|
||||
constexpr uint32_t BITS_PER_WORD = 32;
|
||||
constexpr uint32_t BITS_PER_BYTE = 8;
|
||||
|
||||
@@ -48,12 +48,12 @@ __device__ Torus *get_ith_block(Torus *ksk, int i, int level,
|
||||
// in two parts, a constant part is calculated before the loop, and a variable
|
||||
// part is calculated inside the loop. This seems to help with the register
|
||||
// pressure as well.
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__global__ void
|
||||
keyswitch(Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
|
||||
const Torus *__restrict__ lwe_array_in,
|
||||
const Torus *__restrict__ lwe_input_indexes,
|
||||
const Torus *__restrict__ ksk, uint32_t lwe_dimension_in,
|
||||
const KSTorus *__restrict__ ksk, uint32_t lwe_dimension_in,
|
||||
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count) {
|
||||
const int tid = threadIdx.x + blockIdx.y * blockDim.x;
|
||||
const int shmem_index = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
@@ -107,11 +107,11 @@ keyswitch(Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_keyswitch_lwe_ciphertext_vector(
|
||||
cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out,
|
||||
Torus const *lwe_output_indexes, Torus const *lwe_array_in,
|
||||
Torus const *lwe_input_indexes, Torus const *ksk, uint32_t lwe_dimension_in,
|
||||
Torus const *lwe_input_indexes, KSTorus const *ksk, uint32_t lwe_dimension_in,
|
||||
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
|
||||
uint32_t num_samples) {
|
||||
|
||||
@@ -135,19 +135,19 @@ __host__ void host_keyswitch_lwe_ciphertext_vector(
|
||||
dim3 grid(num_samples, num_blocks_per_sample, 1);
|
||||
dim3 threads(num_threads_x, num_threads_y, 1);
|
||||
|
||||
keyswitch<Torus><<<grid, threads, shared_mem, stream>>>(
|
||||
keyswitch<Torus, KSTorus><<<grid, threads, shared_mem, stream>>>(
|
||||
lwe_array_out, lwe_output_indexes, lwe_array_in, lwe_input_indexes, ksk,
|
||||
lwe_dimension_in, lwe_dimension_out, base_log, level_count);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
void execute_keyswitch_async(CudaStreams streams,
|
||||
const LweArrayVariant<Torus> &lwe_array_out,
|
||||
const LweArrayVariant<Torus> &lwe_output_indexes,
|
||||
const LweArrayVariant<Torus> &lwe_array_in,
|
||||
const LweArrayVariant<Torus> &lwe_input_indexes,
|
||||
Torus *const *ksks, uint32_t lwe_dimension_in,
|
||||
KSTorus *const *ksks, uint32_t lwe_dimension_in,
|
||||
uint32_t lwe_dimension_out, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t num_samples) {
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ void cuda_integer_abs_inplace_radix_ciphertext_kb_64(
|
||||
auto mem = (int_abs_buffer<uint64_t> *)mem_ptr;
|
||||
|
||||
host_integer_abs_kb<uint64_t>(CudaStreams(streams), ct, bsks,
|
||||
(uint64_t **)(ksks), mem, is_signed);
|
||||
(uint32_t **)(ksks), mem, is_signed);
|
||||
}
|
||||
|
||||
void cleanup_cuda_integer_abs_inplace(CudaStreamsFFI streams,
|
||||
|
||||
@@ -32,7 +32,7 @@ __host__ uint64_t scratch_cuda_integer_abs_kb(
|
||||
template <typename Torus>
|
||||
__host__ void
|
||||
host_integer_abs_kb(CudaStreams streams, CudaRadixCiphertextFFI *ct,
|
||||
void *const *bsks, uint64_t *const *ksks,
|
||||
void *const *bsks, uint32_t *const *ksks,
|
||||
int_abs_buffer<uint64_t> *mem_ptr, bool is_signed) {
|
||||
if (!is_signed)
|
||||
return;
|
||||
|
||||
@@ -27,7 +27,7 @@ void cuda_bitop_integer_radix_ciphertext_kb_64(
|
||||
|
||||
host_integer_radix_bitop_kb<uint64_t>(
|
||||
CudaStreams(streams), lwe_array_out, lwe_array_1, lwe_array_2,
|
||||
(int_bitop_buffer<uint64_t> *)mem_ptr, bsks, (uint64_t **)(ksks));
|
||||
(int_bitop_buffer<uint64_t> *)mem_ptr, bsks, (uint32_t **)(ksks));
|
||||
}
|
||||
|
||||
void cleanup_cuda_integer_bitop(CudaStreamsFFI streams, int8_t **mem_ptr_void) {
|
||||
|
||||
@@ -11,12 +11,12 @@
|
||||
#include "utils/kernel_dimensions.cuh"
|
||||
#include <omp.h>
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_radix_bitop_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_1,
|
||||
CudaRadixCiphertextFFI const *lwe_array_2, int_bitop_buffer<Torus> *mem_ptr,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
PANIC_IF_FALSE(
|
||||
lwe_array_out->num_radix_blocks == lwe_array_1->num_radix_blocks &&
|
||||
|
||||
@@ -45,7 +45,7 @@ void cuda_extend_radix_with_sign_msb_64(CudaStreamsFFI streams,
|
||||
host_extend_radix_with_sign_msb<uint64_t>(
|
||||
CudaStreams(streams), output, input,
|
||||
(int_extend_radix_with_sign_msb_buffer<uint64_t> *)mem_ptr,
|
||||
num_additional_blocks, bsks, (uint64_t **)ksks);
|
||||
num_additional_blocks, bsks, (uint32_t **)ksks);
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
|
||||
@@ -50,12 +50,12 @@ __host__ uint64_t scratch_extend_radix_with_sign_msb(
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_extend_radix_with_sign_msb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *output,
|
||||
CudaRadixCiphertextFFI const *input,
|
||||
int_extend_radix_with_sign_msb_buffer<Torus> *mem_ptr,
|
||||
uint32_t num_additional_blocks, void *const *bsks, Torus *const *ksks) {
|
||||
uint32_t num_additional_blocks, void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
if (num_additional_blocks == 0) {
|
||||
PUSH_RANGE("cast/extend no addblocks")
|
||||
|
||||
@@ -34,7 +34,7 @@ void cuda_cmux_integer_radix_ciphertext_kb_64(
|
||||
host_integer_radix_cmux_kb<uint64_t>(
|
||||
CudaStreams(streams), lwe_array_out, lwe_condition, lwe_array_true,
|
||||
lwe_array_false, (int_cmux_buffer<uint64_t> *)mem_ptr, bsks,
|
||||
(uint64_t **)(ksks));
|
||||
(uint32_t **)(ksks));
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
|
||||
@@ -4,14 +4,14 @@
|
||||
#include "integer.cuh"
|
||||
#include "radix_ciphertext.cuh"
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void zero_out_if(CudaStreams streams,
|
||||
CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_input,
|
||||
CudaRadixCiphertextFFI const *lwe_condition,
|
||||
int_zero_out_if_buffer<Torus> *mem_ptr,
|
||||
int_radix_lut<Torus> *predicate, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t num_radix_blocks) {
|
||||
KSTorus *const *ksks, uint32_t num_radix_blocks) {
|
||||
PANIC_IF_FALSE(
|
||||
lwe_array_out->num_radix_blocks >= num_radix_blocks &&
|
||||
lwe_array_input->num_radix_blocks >= num_radix_blocks,
|
||||
@@ -40,13 +40,13 @@ __host__ void zero_out_if(CudaStreams streams,
|
||||
num_radix_blocks);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_radix_cmux_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_condition,
|
||||
CudaRadixCiphertextFFI const *lwe_array_true,
|
||||
CudaRadixCiphertextFFI const *lwe_array_false,
|
||||
int_cmux_buffer<Torus> *mem_ptr, void *const *bsks, Torus *const *ksks) {
|
||||
int_cmux_buffer<Torus> *mem_ptr, void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
if (lwe_array_out->num_radix_blocks != lwe_array_true->num_radix_blocks)
|
||||
PANIC("Cuda error: input and output num radix blocks must be the same")
|
||||
|
||||
@@ -56,7 +56,7 @@ void cuda_comparison_integer_radix_ciphertext_kb_64(
|
||||
case NE:
|
||||
host_integer_radix_equality_check_kb<uint64_t>(
|
||||
CudaStreams(streams), lwe_array_out, lwe_array_1, lwe_array_2, buffer,
|
||||
bsks, (uint64_t **)(ksks), num_radix_blocks);
|
||||
bsks, (uint32_t **)(ksks), num_radix_blocks);
|
||||
break;
|
||||
case GT:
|
||||
case GE:
|
||||
@@ -67,7 +67,7 @@ void cuda_comparison_integer_radix_ciphertext_kb_64(
|
||||
"even.")
|
||||
host_integer_radix_difference_check_kb<uint64_t>(
|
||||
CudaStreams(streams), lwe_array_out, lwe_array_1, lwe_array_2, buffer,
|
||||
buffer->diff_buffer->operator_f, bsks, (uint64_t **)(ksks),
|
||||
buffer->diff_buffer->operator_f, bsks, (uint32_t **)(ksks),
|
||||
num_radix_blocks);
|
||||
break;
|
||||
case MAX:
|
||||
@@ -76,7 +76,7 @@ void cuda_comparison_integer_radix_ciphertext_kb_64(
|
||||
PANIC("Cuda error (max/min): the number of radix blocks has to be even.")
|
||||
host_integer_radix_maxmin_kb<uint64_t>(
|
||||
CudaStreams(streams), lwe_array_out, lwe_array_1, lwe_array_2, buffer,
|
||||
bsks, (uint64_t **)(ksks), num_radix_blocks);
|
||||
bsks, (uint32_t **)(ksks), num_radix_blocks);
|
||||
break;
|
||||
default:
|
||||
PANIC("Cuda error: integer operation not supported")
|
||||
@@ -124,7 +124,7 @@ void cuda_integer_are_all_comparisons_block_true_kb_64(
|
||||
|
||||
host_integer_are_all_comparisons_block_true_kb<uint64_t>(
|
||||
CudaStreams(streams), lwe_array_out, lwe_array_in, buffer, bsks,
|
||||
(uint64_t **)(ksks), num_radix_blocks);
|
||||
(uint32_t **)(ksks), num_radix_blocks);
|
||||
}
|
||||
|
||||
void cleanup_cuda_integer_are_all_comparisons_block_true(
|
||||
@@ -166,7 +166,7 @@ void cuda_integer_is_at_least_one_comparisons_block_true_kb_64(
|
||||
|
||||
host_integer_is_at_least_one_comparisons_block_true_kb<uint64_t>(
|
||||
CudaStreams(streams), lwe_array_out, lwe_array_in, buffer, bsks,
|
||||
(uint64_t **)(ksks), num_radix_blocks);
|
||||
(uint32_t **)(ksks), num_radix_blocks);
|
||||
}
|
||||
|
||||
void cleanup_cuda_integer_is_at_least_one_comparisons_block_true(
|
||||
|
||||
@@ -56,12 +56,12 @@ __host__ void accumulate_all_blocks(cudaStream_t stream, uint32_t gpu_index,
|
||||
* blocks are 1 otherwise the block encrypts 0
|
||||
*
|
||||
*/
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void are_all_comparisons_block_true(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_in,
|
||||
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t num_radix_blocks) {
|
||||
KSTorus *const *ksks, uint32_t num_radix_blocks) {
|
||||
|
||||
if (lwe_array_out->lwe_dimension != lwe_array_in->lwe_dimension)
|
||||
PANIC("Cuda error: input and output lwe dimensions must be the same")
|
||||
@@ -184,12 +184,12 @@ __host__ void are_all_comparisons_block_true(
|
||||
* It writes in lwe_array_out a single lwe ciphertext encrypting 1 if at least
|
||||
* one input ciphertext encrypts 1 otherwise encrypts 0
|
||||
*/
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void is_at_least_one_comparisons_block_true(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_in,
|
||||
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t num_radix_blocks) {
|
||||
KSTorus *const *ksks, uint32_t num_radix_blocks) {
|
||||
|
||||
if (lwe_array_out->lwe_dimension != lwe_array_in->lwe_dimension)
|
||||
PANIC("Cuda error: input lwe dimensions must be the same")
|
||||
@@ -253,12 +253,12 @@ __host__ void is_at_least_one_comparisons_block_true(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_compare_blocks_with_zero(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_in,
|
||||
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks, int32_t num_radix_blocks,
|
||||
KSTorus *const *ksks, int32_t num_radix_blocks,
|
||||
int_radix_lut<Torus> *zero_comparison) {
|
||||
|
||||
if (num_radix_blocks == 0)
|
||||
@@ -320,13 +320,13 @@ __host__ void host_compare_blocks_with_zero(
|
||||
reset_radix_ciphertext_blocks(lwe_array_out, num_sum_blocks);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_radix_equality_check_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_1,
|
||||
CudaRadixCiphertextFFI const *lwe_array_2,
|
||||
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t num_radix_blocks) {
|
||||
KSTorus *const *ksks, uint32_t num_radix_blocks) {
|
||||
|
||||
if (lwe_array_out->lwe_dimension != lwe_array_1->lwe_dimension ||
|
||||
lwe_array_out->lwe_dimension != lwe_array_2->lwe_dimension)
|
||||
@@ -348,13 +348,13 @@ __host__ void host_integer_radix_equality_check_kb(
|
||||
mem_ptr, bsks, ksks, num_radix_blocks);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void compare_radix_blocks_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_left,
|
||||
CudaRadixCiphertextFFI const *lwe_array_right,
|
||||
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t num_radix_blocks) {
|
||||
KSTorus *const *ksks, uint32_t num_radix_blocks) {
|
||||
|
||||
if (lwe_array_out->lwe_dimension != lwe_array_left->lwe_dimension ||
|
||||
lwe_array_out->lwe_dimension != lwe_array_right->lwe_dimension)
|
||||
@@ -400,13 +400,13 @@ __host__ void compare_radix_blocks_kb(
|
||||
// Reduces a vec containing shortint blocks that encrypts a sign
|
||||
// (inferior, equal, superior) to one single shortint block containing the
|
||||
// final sign
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void
|
||||
tree_sign_reduction(CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI *lwe_block_comparisons,
|
||||
int_tree_sign_reduction_buffer<Torus> *tree_buffer,
|
||||
std::function<Torus(Torus)> sign_handler_f,
|
||||
void *const *bsks, Torus *const *ksks,
|
||||
void *const *bsks, KSTorus *const *ksks,
|
||||
uint32_t num_radix_blocks) {
|
||||
|
||||
if (lwe_array_out->lwe_dimension != lwe_block_comparisons->lwe_dimension)
|
||||
@@ -489,14 +489,14 @@ tree_sign_reduction(CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
streams, lwe_array_out, y, bsks, ksks, last_lut, 1);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_radix_difference_check_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_left,
|
||||
CudaRadixCiphertextFFI const *lwe_array_right,
|
||||
int_comparison_buffer<Torus> *mem_ptr,
|
||||
std::function<Torus(Torus)> reduction_lut_f, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t num_radix_blocks) {
|
||||
KSTorus *const *ksks, uint32_t num_radix_blocks) {
|
||||
|
||||
if (lwe_array_out->lwe_dimension != lwe_array_left->lwe_dimension ||
|
||||
lwe_array_out->lwe_dimension != lwe_array_right->lwe_dimension)
|
||||
@@ -657,13 +657,13 @@ __host__ uint64_t scratch_cuda_integer_radix_comparison_check_kb(
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_radix_maxmin_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_left,
|
||||
CudaRadixCiphertextFFI const *lwe_array_right,
|
||||
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t num_radix_blocks) {
|
||||
KSTorus *const *ksks, uint32_t num_radix_blocks) {
|
||||
|
||||
if (lwe_array_out->lwe_dimension != lwe_array_left->lwe_dimension ||
|
||||
lwe_array_out->lwe_dimension != lwe_array_right->lwe_dimension)
|
||||
@@ -685,12 +685,12 @@ __host__ void host_integer_radix_maxmin_kb(
|
||||
lwe_array_right, mem_ptr->cmux_buffer, bsks, ksks);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_are_all_comparisons_block_true_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_in,
|
||||
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t num_radix_blocks) {
|
||||
KSTorus *const *ksks, uint32_t num_radix_blocks) {
|
||||
|
||||
// It returns a block encrypting 1 if all input blocks are 1
|
||||
// otherwise the block encrypts 0
|
||||
@@ -698,12 +698,12 @@ __host__ void host_integer_are_all_comparisons_block_true_kb(
|
||||
mem_ptr, bsks, ksks, num_radix_blocks);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_is_at_least_one_comparisons_block_true_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_in,
|
||||
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t num_radix_blocks) {
|
||||
KSTorus *const *ksks, uint32_t num_radix_blocks) {
|
||||
|
||||
// It returns a block encrypting 1 if at least one input block is 1
|
||||
// otherwise the block encrypts 0
|
||||
|
||||
@@ -30,7 +30,7 @@ void cuda_integer_div_rem_radix_ciphertext_kb_64(
|
||||
|
||||
host_integer_div_rem_kb<uint64_t>(CudaStreams(streams), quotient, remainder,
|
||||
numerator, divisor, is_signed, bsks,
|
||||
(uint64_t **)(ksks), mem);
|
||||
(uint32_t **)(ksks), mem);
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *quotient,
|
||||
CudaRadixCiphertextFFI *remainder, CudaRadixCiphertextFFI const *numerator,
|
||||
CudaRadixCiphertextFFI const *divisor, void *const *bsks,
|
||||
uint64_t *const *ksks, unsigned_int_div_rem_2_2_memory<uint64_t> *mem_ptr) {
|
||||
uint32_t *const *ksks, unsigned_int_div_rem_2_2_memory<uint64_t> *mem_ptr) {
|
||||
|
||||
if (streams.count() < 4) {
|
||||
PANIC("GPU count should be greater than 4 when using div_rem_2_2");
|
||||
@@ -480,7 +480,7 @@ __host__ void host_unsigned_integer_div_rem_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *quotient,
|
||||
CudaRadixCiphertextFFI *remainder, CudaRadixCiphertextFFI const *numerator,
|
||||
CudaRadixCiphertextFFI const *divisor, void *const *bsks,
|
||||
uint64_t *const *ksks, unsigned_int_div_rem_memory<uint64_t> *mem_ptr) {
|
||||
uint32_t *const *ksks, unsigned_int_div_rem_memory<uint64_t> *mem_ptr) {
|
||||
|
||||
if (remainder->num_radix_blocks != numerator->num_radix_blocks ||
|
||||
remainder->num_radix_blocks != divisor->num_radix_blocks ||
|
||||
@@ -910,7 +910,7 @@ __host__ void host_integer_div_rem_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *quotient,
|
||||
CudaRadixCiphertextFFI *remainder, CudaRadixCiphertextFFI const *numerator,
|
||||
CudaRadixCiphertextFFI const *divisor, bool is_signed, void *const *bsks,
|
||||
uint64_t *const *ksks, int_div_rem_memory<uint64_t> *int_mem_ptr) {
|
||||
uint32_t *const *ksks, int_div_rem_memory<uint64_t> *int_mem_ptr) {
|
||||
if (remainder->num_radix_blocks != numerator->num_radix_blocks ||
|
||||
remainder->num_radix_blocks != divisor->num_radix_blocks ||
|
||||
remainder->num_radix_blocks != quotient->num_radix_blocks)
|
||||
|
||||
@@ -31,10 +31,10 @@ void cuda_integer_count_of_consecutive_bits_kb_64(
|
||||
CudaRadixCiphertextFFI const *input_ct, int8_t *mem_ptr, void *const *bsks,
|
||||
void *const *ksks) {
|
||||
|
||||
host_integer_count_of_consecutive_bits<uint64_t>(
|
||||
host_integer_count_of_consecutive_bits<uint64_t, uint32_t>(
|
||||
CudaStreams(streams), output_ct, input_ct,
|
||||
(int_count_of_consecutive_bits_buffer<uint64_t> *)mem_ptr, bsks,
|
||||
(uint64_t **)ksks);
|
||||
(uint32_t **)ksks);
|
||||
}
|
||||
|
||||
void cleanup_cuda_integer_count_of_consecutive_bits_kb_64(
|
||||
@@ -82,10 +82,10 @@ void cuda_integer_ilog2_kb_64(
|
||||
CudaRadixCiphertextFFI const *trivial_ct_m_minus_1_block, int8_t *mem_ptr,
|
||||
void *const *bsks, void *const *ksks) {
|
||||
|
||||
host_integer_ilog2<uint64_t>(
|
||||
host_integer_ilog2<uint64_t, uint32_t>(
|
||||
CudaStreams(streams), output_ct, input_ct, trivial_ct_neg_n, trivial_ct_2,
|
||||
trivial_ct_m_minus_1_block, (int_ilog2_buffer<uint64_t> *)mem_ptr, bsks,
|
||||
(uint64_t **)ksks);
|
||||
(uint32_t **)ksks);
|
||||
}
|
||||
|
||||
void cleanup_cuda_integer_ilog2_kb_64(CudaStreamsFFI streams,
|
||||
|
||||
@@ -5,11 +5,11 @@
|
||||
#include "integer/integer_utilities.h"
|
||||
#include "multiplication.cuh"
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_prepare_count_of_consecutive_bits(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *ciphertext,
|
||||
int_prepare_count_of_consecutive_bits_buffer<Torus> *mem_ptr,
|
||||
void *const *bsks, Torus *const *ksks) {
|
||||
void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
auto tmp = mem_ptr->tmp_ct;
|
||||
|
||||
@@ -41,12 +41,12 @@ __host__ uint64_t scratch_integer_count_of_consecutive_bits(
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_count_of_consecutive_bits(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *output_ct,
|
||||
CudaRadixCiphertextFFI const *input_ct,
|
||||
int_count_of_consecutive_bits_buffer<Torus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks) {
|
||||
KSTorus *const *ksks) {
|
||||
|
||||
auto params = mem_ptr->params;
|
||||
auto ct_prepared = mem_ptr->ct_prepared;
|
||||
@@ -97,7 +97,7 @@ __host__ uint64_t scratch_integer_ilog2(CudaStreams streams,
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void
|
||||
host_integer_ilog2(CudaStreams streams, CudaRadixCiphertextFFI *output_ct,
|
||||
CudaRadixCiphertextFFI const *input_ct,
|
||||
@@ -105,7 +105,7 @@ host_integer_ilog2(CudaStreams streams, CudaRadixCiphertextFFI *output_ct,
|
||||
CudaRadixCiphertextFFI const *trivial_ct_2,
|
||||
CudaRadixCiphertextFFI const *trivial_ct_m_minus_1_block,
|
||||
int_ilog2_buffer<Torus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks) {
|
||||
KSTorus *const *ksks) {
|
||||
|
||||
// Prepare the input ciphertext by computing the number of consecutive
|
||||
// leading zeros for each of its blocks.
|
||||
|
||||
@@ -11,7 +11,7 @@ void cuda_full_propagation_64_inplace(CudaStreamsFFI streams,
|
||||
(int_fullprop_buffer<uint64_t> *)mem_ptr;
|
||||
|
||||
host_full_propagate_inplace<uint64_t>(CudaStreams(streams), input_blocks,
|
||||
buffer, (uint64_t **)(ksks), bsks,
|
||||
buffer, (uint32_t **)(ksks), bsks,
|
||||
num_blocks);
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ void cuda_propagate_single_carry_kb_64_inplace(
|
||||
|
||||
host_propagate_single_carry<uint64_t>(
|
||||
CudaStreams(streams), lwe_array, carry_out, carry_in,
|
||||
(int_sc_prop_memory<uint64_t> *)mem_ptr, bsks, (uint64_t **)(ksks),
|
||||
(int_sc_prop_memory<uint64_t> *)mem_ptr, bsks, (uint32_t **)(ksks),
|
||||
requested_flag, uses_carry);
|
||||
}
|
||||
|
||||
@@ -118,7 +118,7 @@ void cuda_add_and_propagate_single_carry_kb_64_inplace(
|
||||
|
||||
host_add_and_propagate_single_carry<uint64_t>(
|
||||
CudaStreams(streams), lhs_array, rhs_array, carry_out, carry_in,
|
||||
(int_sc_prop_memory<uint64_t> *)mem_ptr, bsks, (uint64_t **)(ksks),
|
||||
(int_sc_prop_memory<uint64_t> *)mem_ptr, bsks, (uint32_t **)(ksks),
|
||||
requested_flag, uses_carry);
|
||||
}
|
||||
|
||||
@@ -133,7 +133,7 @@ void cuda_integer_overflowing_sub_kb_64_inplace(
|
||||
host_integer_overflowing_sub<uint64_t>(
|
||||
CudaStreams(streams), lhs_array, lhs_array, rhs_array, overflow_block,
|
||||
input_borrow, (int_borrow_prop_memory<uint64_t> *)mem_ptr, bsks,
|
||||
(uint64_t **)ksks, compute_overflow, uses_input_borrow);
|
||||
(uint32_t **)ksks, compute_overflow, uses_input_borrow);
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
@@ -216,7 +216,7 @@ void cuda_apply_univariate_lut_kb_64(
|
||||
|
||||
host_apply_univariate_lut_kb<uint64_t>(
|
||||
CudaStreams(streams), output_radix_lwe, input_radix_lwe,
|
||||
(int_radix_lut<uint64_t> *)mem_ptr, (uint64_t **)(ksks), bsks);
|
||||
(int_radix_lut<uint64_t> *)mem_ptr, (uint32_t **)(ksks), bsks);
|
||||
}
|
||||
|
||||
void cleanup_cuda_apply_univariate_lut_kb_64(CudaStreamsFFI streams,
|
||||
@@ -237,7 +237,7 @@ void cuda_apply_many_univariate_lut_kb_64(
|
||||
|
||||
host_apply_many_univariate_lut_kb<uint64_t>(
|
||||
CudaStreams(streams), output_radix_lwe, input_radix_lwe,
|
||||
(int_radix_lut<uint64_t> *)mem_ptr, (uint64_t **)(ksks), bsks,
|
||||
(int_radix_lut<uint64_t> *)mem_ptr, (uint32_t **)(ksks), bsks,
|
||||
num_many_lut, lut_stride);
|
||||
}
|
||||
|
||||
@@ -271,7 +271,7 @@ void cuda_apply_bivariate_lut_kb_64(
|
||||
host_apply_bivariate_lut_kb<uint64_t>(
|
||||
CudaStreams(streams), output_radix_lwe, input_radix_lwe_1,
|
||||
input_radix_lwe_2, (int_radix_lut<uint64_t> *)mem_ptr,
|
||||
(uint64_t **)(ksks), bsks, num_radix_blocks, shift);
|
||||
(uint32_t **)(ksks), bsks, num_radix_blocks, shift);
|
||||
}
|
||||
|
||||
void cleanup_cuda_apply_bivariate_lut_kb_64(CudaStreamsFFI streams,
|
||||
@@ -312,7 +312,7 @@ void cuda_integer_compute_prefix_sum_hillis_steele_64(
|
||||
|
||||
host_compute_prefix_sum_hillis_steele<uint64_t>(
|
||||
CudaStreams(streams), output_radix_lwe, generates_or_propagates,
|
||||
(int_radix_lut<uint64_t> *)mem_ptr, bsks, (uint64_t **)(ksks),
|
||||
(int_radix_lut<uint64_t> *)mem_ptr, bsks, (uint32_t **)(ksks),
|
||||
num_radix_blocks);
|
||||
}
|
||||
|
||||
@@ -390,7 +390,7 @@ void cuda_apply_noise_squashing_kb(
|
||||
PUSH_RANGE("apply noise squashing")
|
||||
integer_radix_apply_noise_squashing_kb<uint64_t>(
|
||||
CudaStreams(streams), output_radix_lwe, input_radix_lwe,
|
||||
(int_noise_squashing_lut<uint64_t> *)mem_ptr, bsks, (uint64_t **)ksks);
|
||||
(int_noise_squashing_lut<uint64_t> *)mem_ptr, bsks, (uint32_t **)ksks);
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
|
||||
@@ -503,11 +503,11 @@ __host__ void host_pack_bivariate_blocks_with_single_block(
|
||||
/// num_radix_blocks corresponds to the number of blocks on which to apply the
|
||||
/// LUT. In scalar bitops we use a number of blocks that may be lower or equal to
|
||||
/// the input and output numbers of blocks
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void integer_radix_apply_univariate_lookup_table_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_in, void *const *bsks,
|
||||
Torus *const *ksks, int_radix_lut<Torus> *lut, uint32_t num_radix_blocks) {
|
||||
KSTorus *const *ksks, int_radix_lut<Torus> *lut, uint32_t num_radix_blocks) {
|
||||
PUSH_RANGE("apply lut")
|
||||
// apply_lookup_table
|
||||
auto params = lut->params;
|
||||
@@ -607,11 +607,11 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void integer_radix_apply_many_univariate_lookup_table_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_in, void *const *bsks,
|
||||
Torus *const *ksks, int_radix_lut<Torus> *lut, uint32_t num_many_lut,
|
||||
KSTorus *const *ksks, int_radix_lut<Torus> *lut, uint32_t num_many_lut,
|
||||
uint32_t lut_stride) {
|
||||
PUSH_RANGE("apply many lut")
|
||||
// apply_lookup_table
|
||||
@@ -710,12 +710,12 @@ __host__ void integer_radix_apply_many_univariate_lookup_table_kb(
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void integer_radix_apply_bivariate_lookup_table_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_1,
|
||||
CudaRadixCiphertextFFI const *lwe_array_2, void *const *bsks,
|
||||
Torus *const *ksks, int_radix_lut<Torus> *lut, uint32_t num_radix_blocks,
|
||||
KSTorus *const *ksks, int_radix_lut<Torus> *lut, uint32_t num_radix_blocks,
|
||||
uint32_t shift) {
|
||||
PUSH_RANGE("apply bivar lut")
|
||||
if (lwe_array_out->lwe_dimension != lwe_array_1->lwe_dimension ||
|
||||
@@ -1269,11 +1269,11 @@ void generate_many_lut_device_accumulator(
|
||||
// block states: contains the propagation states for the different blocks
|
||||
// depending on the group it belongs to and the internal position within the
|
||||
// block.
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
void host_compute_shifted_blocks_and_states(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array,
|
||||
int_shifted_blocks_and_states_memory<Torus> *mem, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t lut_stride, uint32_t num_many_lut) {
|
||||
KSTorus *const *ksks, uint32_t lut_stride, uint32_t num_many_lut) {
|
||||
|
||||
auto num_radix_blocks = lwe_array->num_radix_blocks;
|
||||
|
||||
@@ -1296,12 +1296,12 @@ void host_compute_shifted_blocks_and_states(
|
||||
2 * num_radix_blocks);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
void host_resolve_group_carries_sequentially(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *resolved_carries,
|
||||
CudaRadixCiphertextFFI *grouping_pgns, int_radix_params params,
|
||||
int_seq_group_prop_memory<Torus> *mem, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t num_groups) {
|
||||
KSTorus *const *ksks, uint32_t num_groups) {
|
||||
|
||||
auto group_resolved_carries = mem->group_resolved_carries;
|
||||
if (num_groups > 1) {
|
||||
@@ -1364,11 +1364,11 @@ void host_resolve_group_carries_sequentially(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
void host_compute_prefix_sum_hillis_steele(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *step_output,
|
||||
CudaRadixCiphertextFFI *generates_or_propagates, int_radix_lut<Torus> *luts,
|
||||
void *const *bsks, Torus *const *ksks, uint32_t num_radix_blocks) {
|
||||
void *const *bsks, KSTorus *const *ksks, uint32_t num_radix_blocks) {
|
||||
|
||||
if (step_output->lwe_dimension != generates_or_propagates->lwe_dimension)
|
||||
PANIC("Cuda error: input lwe dimensions must be the same")
|
||||
@@ -1407,11 +1407,11 @@ void host_compute_prefix_sum_hillis_steele(
|
||||
// - calculates the propagation state of each group
|
||||
// - resolves the carries between groups, either sequentially or with hillis
|
||||
// steele
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
void host_compute_propagation_simulators_and_group_carries(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *block_states,
|
||||
int_radix_params params, int_prop_simu_group_carries_memory<Torus> *mem,
|
||||
void *const *bsks, Torus *const *ksks, uint32_t num_radix_blocks,
|
||||
void *const *bsks, KSTorus *const *ksks, uint32_t num_radix_blocks,
|
||||
uint32_t num_groups) {
|
||||
|
||||
if (num_radix_blocks > block_states->num_radix_blocks)
|
||||
@@ -1469,11 +1469,11 @@ void host_compute_propagation_simulators_and_group_carries(
|
||||
// block states: contains the propagation states for the different blocks
|
||||
// depending on the group it belongs to and the internal position within the
|
||||
// block.
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
void host_compute_shifted_blocks_and_borrow_states(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array,
|
||||
int_shifted_blocks_and_borrow_states_memory<Torus> *mem, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t lut_stride, uint32_t num_many_lut) {
|
||||
KSTorus *const *ksks, uint32_t lut_stride, uint32_t num_many_lut) {
|
||||
auto num_radix_blocks = lwe_array->num_radix_blocks;
|
||||
|
||||
auto shifted_blocks_and_borrow_states = mem->shifted_blocks_and_borrow_states;
|
||||
@@ -1502,11 +1502,11 @@ void host_compute_shifted_blocks_and_borrow_states(
|
||||
* * (lwe_dimension + 1) * sizeof(Torus) big_lwe_vector: output of pbs should
|
||||
* have size = 2 * (glwe_dimension * polynomial_size + 1) * sizeof(Torus)
|
||||
*/
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
void host_full_propagate_inplace(CudaStreams streams,
|
||||
CudaRadixCiphertextFFI *input_blocks,
|
||||
int_fullprop_buffer<Torus> *mem_ptr,
|
||||
Torus *const *ksks, void *const *bsks,
|
||||
KSTorus *const *ksks, void *const *bsks,
|
||||
uint32_t num_blocks) {
|
||||
auto params = mem_ptr->lut->params;
|
||||
|
||||
@@ -1664,11 +1664,11 @@ __host__ void scalar_pack_blocks(cudaStream_t stream, uint32_t gpu_index,
|
||||
* Thus, lwe_array_out must be allocated with num_radix_blocks * bits_per_block
|
||||
* * (lwe_dimension+1) * sizeof(Torus) bytes
|
||||
*/
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void
|
||||
extract_n_bits(CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
const CudaRadixCiphertextFFI *lwe_array_in, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t effective_num_radix_blocks,
|
||||
KSTorus *const *ksks, uint32_t effective_num_radix_blocks,
|
||||
uint32_t num_radix_blocks,
|
||||
int_bit_extract_luts_buffer<Torus> *bit_extract) {
|
||||
|
||||
@@ -1688,13 +1688,13 @@ extract_n_bits(CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
effective_num_radix_blocks);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void
|
||||
reduce_signs(CudaStreams streams, CudaRadixCiphertextFFI *signs_array_out,
|
||||
CudaRadixCiphertextFFI *signs_array_in,
|
||||
int_comparison_buffer<Torus> *mem_ptr,
|
||||
std::function<Torus(Torus)> sign_handler_f, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t num_sign_blocks) {
|
||||
KSTorus *const *ksks, uint32_t num_sign_blocks) {
|
||||
|
||||
if (signs_array_out->lwe_dimension != signs_array_in->lwe_dimension)
|
||||
PANIC("Cuda error: input lwe dimensions must be the same")
|
||||
@@ -1814,11 +1814,11 @@ uint64_t scratch_cuda_apply_univariate_lut_kb(
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
void host_apply_univariate_lut_kb(CudaStreams streams,
|
||||
CudaRadixCiphertextFFI *radix_lwe_out,
|
||||
CudaRadixCiphertextFFI const *radix_lwe_in,
|
||||
int_radix_lut<Torus> *mem, Torus *const *ksks,
|
||||
int_radix_lut<Torus> *mem, KSTorus *const *ksks,
|
||||
void *const *bsks) {
|
||||
|
||||
integer_radix_apply_univariate_lookup_table_kb<Torus>(
|
||||
@@ -1849,11 +1849,11 @@ uint64_t scratch_cuda_apply_many_univariate_lut_kb(
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
void host_apply_many_univariate_lut_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *radix_lwe_out,
|
||||
CudaRadixCiphertextFFI const *radix_lwe_in, int_radix_lut<Torus> *mem,
|
||||
Torus *const *ksks, void *const *bsks, uint32_t num_many_lut,
|
||||
KSTorus *const *ksks, void *const *bsks, uint32_t num_many_lut,
|
||||
uint32_t lut_stride) {
|
||||
|
||||
integer_radix_apply_many_univariate_lookup_table_kb<Torus>(
|
||||
@@ -1883,12 +1883,12 @@ uint64_t scratch_cuda_apply_bivariate_lut_kb(
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
void host_apply_bivariate_lut_kb(CudaStreams streams,
|
||||
CudaRadixCiphertextFFI *radix_lwe_out,
|
||||
CudaRadixCiphertextFFI const *radix_lwe_in_1,
|
||||
CudaRadixCiphertextFFI const *radix_lwe_in_2,
|
||||
int_radix_lut<Torus> *mem, Torus *const *ksks,
|
||||
int_radix_lut<Torus> *mem, KSTorus *const *ksks,
|
||||
void *const *bsks, uint32_t num_radix_blocks,
|
||||
uint32_t shift) {
|
||||
|
||||
@@ -1912,13 +1912,13 @@ uint64_t scratch_cuda_propagate_single_carry_kb_inplace(
|
||||
}
|
||||
// This function performs the three steps of Thomas' new carry propagation
|
||||
// and includes the logic to extract overflow when requested
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
void host_propagate_single_carry(CudaStreams streams,
|
||||
CudaRadixCiphertextFFI *lwe_array,
|
||||
CudaRadixCiphertextFFI *carry_out,
|
||||
const CudaRadixCiphertextFFI *input_carries,
|
||||
int_sc_prop_memory<Torus> *mem,
|
||||
void *const *bsks, Torus *const *ksks,
|
||||
void *const *bsks, KSTorus *const *ksks,
|
||||
uint32_t requested_flag, uint32_t uses_carry) {
|
||||
PUSH_RANGE("propagate sc")
|
||||
auto num_radix_blocks = lwe_array->num_radix_blocks;
|
||||
@@ -2014,12 +2014,12 @@ void host_propagate_single_carry(CudaStreams streams,
|
||||
|
||||
// This function performs the three steps of Thomas' new carry propagation
|
||||
// and includes the logic to extract overflow when requested
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
void host_add_and_propagate_single_carry(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lhs_array,
|
||||
const CudaRadixCiphertextFFI *rhs_array, CudaRadixCiphertextFFI *carry_out,
|
||||
const CudaRadixCiphertextFFI *input_carries, int_sc_prop_memory<Torus> *mem,
|
||||
void *const *bsks, Torus *const *ksks, uint32_t requested_flag,
|
||||
void *const *bsks, KSTorus *const *ksks, uint32_t requested_flag,
|
||||
uint32_t uses_carry) {
|
||||
PUSH_RANGE("add & propagate sc")
|
||||
if (lhs_array->num_radix_blocks != rhs_array->num_radix_blocks)
|
||||
@@ -2176,13 +2176,13 @@ uint64_t scratch_cuda_integer_overflowing_sub(
|
||||
|
||||
// This function performs the three steps of Thomas' new borrow propagation
|
||||
// and includes the logic to extract overflow when requested
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
void host_single_borrow_propagate(CudaStreams streams,
|
||||
CudaRadixCiphertextFFI *lwe_array,
|
||||
CudaRadixCiphertextFFI *overflow_block,
|
||||
const CudaRadixCiphertextFFI *input_borrow,
|
||||
int_borrow_prop_memory<Torus> *mem,
|
||||
void *const *bsks, Torus *const *ksks,
|
||||
void *const *bsks, KSTorus *const *ksks,
|
||||
uint32_t num_groups,
|
||||
uint32_t compute_overflow,
|
||||
uint32_t uses_input_borrow) {
|
||||
@@ -2308,12 +2308,12 @@ void host_single_borrow_propagate(CudaStreams streams,
|
||||
/// num_radix_blocks corresponds to the number of blocks on which to apply the
|
||||
/// LUT. In scalar bitops we use a number of blocks that may be lower or equal to
|
||||
/// the input and output numbers of blocks
|
||||
template <typename InputTorus>
|
||||
template <typename InputTorus, typename KSTorus>
|
||||
__host__ void integer_radix_apply_noise_squashing_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_in,
|
||||
int_noise_squashing_lut<InputTorus> *lut, void *const *bsks,
|
||||
InputTorus *const *ksks) {
|
||||
KSTorus *const *ksks) {
|
||||
|
||||
PUSH_RANGE("apply noise squashing")
|
||||
auto params = lut->params;
|
||||
|
||||
@@ -135,43 +135,43 @@ void cuda_integer_mult_radix_ciphertext_kb_64(
|
||||
case 256:
|
||||
host_integer_mult_radix_kb<uint64_t, AmortizedDegree<256>>(
|
||||
CudaStreams(streams), radix_lwe_out, radix_lwe_left, is_bool_left,
|
||||
radix_lwe_right, is_bool_right, bsks, (uint64_t **)(ksks),
|
||||
radix_lwe_right, is_bool_right, bsks, (uint32_t **)(ksks),
|
||||
(int_mul_memory<uint64_t> *)mem_ptr, num_blocks);
|
||||
break;
|
||||
case 512:
|
||||
host_integer_mult_radix_kb<uint64_t, AmortizedDegree<512>>(
|
||||
CudaStreams(streams), radix_lwe_out, radix_lwe_left, is_bool_left,
|
||||
radix_lwe_right, is_bool_right, bsks, (uint64_t **)(ksks),
|
||||
radix_lwe_right, is_bool_right, bsks, (uint32_t **)(ksks),
|
||||
(int_mul_memory<uint64_t> *)mem_ptr, num_blocks);
|
||||
break;
|
||||
case 1024:
|
||||
host_integer_mult_radix_kb<uint64_t, AmortizedDegree<1024>>(
|
||||
CudaStreams(streams), radix_lwe_out, radix_lwe_left, is_bool_left,
|
||||
radix_lwe_right, is_bool_right, bsks, (uint64_t **)(ksks),
|
||||
radix_lwe_right, is_bool_right, bsks, (uint32_t **)(ksks),
|
||||
(int_mul_memory<uint64_t> *)mem_ptr, num_blocks);
|
||||
break;
|
||||
case 2048:
|
||||
host_integer_mult_radix_kb<uint64_t, AmortizedDegree<2048>>(
|
||||
CudaStreams(streams), radix_lwe_out, radix_lwe_left, is_bool_left,
|
||||
radix_lwe_right, is_bool_right, bsks, (uint64_t **)(ksks),
|
||||
radix_lwe_right, is_bool_right, bsks, (uint32_t **)(ksks),
|
||||
(int_mul_memory<uint64_t> *)mem_ptr, num_blocks);
|
||||
break;
|
||||
case 4096:
|
||||
host_integer_mult_radix_kb<uint64_t, AmortizedDegree<4096>>(
|
||||
CudaStreams(streams), radix_lwe_out, radix_lwe_left, is_bool_left,
|
||||
radix_lwe_right, is_bool_right, bsks, (uint64_t **)(ksks),
|
||||
radix_lwe_right, is_bool_right, bsks, (uint32_t **)(ksks),
|
||||
(int_mul_memory<uint64_t> *)mem_ptr, num_blocks);
|
||||
break;
|
||||
case 8192:
|
||||
host_integer_mult_radix_kb<uint64_t, AmortizedDegree<8192>>(
|
||||
CudaStreams(streams), radix_lwe_out, radix_lwe_left, is_bool_left,
|
||||
radix_lwe_right, is_bool_right, bsks, (uint64_t **)(ksks),
|
||||
radix_lwe_right, is_bool_right, bsks, (uint32_t **)(ksks),
|
||||
(int_mul_memory<uint64_t> *)mem_ptr, num_blocks);
|
||||
break;
|
||||
case 16384:
|
||||
host_integer_mult_radix_kb<uint64_t, AmortizedDegree<16384>>(
|
||||
CudaStreams(streams), radix_lwe_out, radix_lwe_left, is_bool_left,
|
||||
radix_lwe_right, is_bool_right, bsks, (uint64_t **)(ksks),
|
||||
radix_lwe_right, is_bool_right, bsks, (uint32_t **)(ksks),
|
||||
(int_mul_memory<uint64_t> *)mem_ptr, num_blocks);
|
||||
break;
|
||||
default:
|
||||
@@ -225,7 +225,7 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
|
||||
"output's number of radix blocks")
|
||||
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t>(
|
||||
CudaStreams(streams), radix_lwe_out, radix_lwe_vec, bsks,
|
||||
(uint64_t **)(ksks), mem, radix_lwe_out->num_radix_blocks,
|
||||
(uint32_t **)(ksks), mem, radix_lwe_out->num_radix_blocks,
|
||||
radix_lwe_vec->num_radix_blocks / radix_lwe_out->num_radix_blocks);
|
||||
}
|
||||
|
||||
|
||||
@@ -290,7 +290,7 @@ __host__ uint64_t scratch_cuda_integer_partial_sum_ciphertexts_vec_kb(
|
||||
template <typename Torus>
|
||||
__host__ void host_integer_partial_sum_ciphertexts_vec_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *radix_lwe_out,
|
||||
CudaRadixCiphertextFFI *terms, void *const *bsks, uint64_t *const *ksks,
|
||||
CudaRadixCiphertextFFI *terms, void *const *bsks, uint32_t *const *ksks,
|
||||
int_sum_ciphertexts_vec_memory<uint64_t> *mem_ptr,
|
||||
uint32_t num_radix_blocks, uint32_t num_radix_in_vec) {
|
||||
auto big_lwe_dimension = mem_ptr->params.big_lwe_dimension;
|
||||
@@ -492,7 +492,7 @@ __host__ void host_integer_mult_radix_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *radix_lwe_out,
|
||||
CudaRadixCiphertextFFI const *radix_lwe_left, bool const is_bool_left,
|
||||
CudaRadixCiphertextFFI const *radix_lwe_right, bool const is_bool_right,
|
||||
void *const *bsks, uint64_t *const *ksks, int_mul_memory<Torus> *mem_ptr,
|
||||
void *const *bsks, uint32_t *const *ksks, int_mul_memory<Torus> *mem_ptr,
|
||||
uint32_t num_blocks) {
|
||||
|
||||
if (radix_lwe_out->lwe_dimension != radix_lwe_left->lwe_dimension ||
|
||||
|
||||
@@ -126,7 +126,7 @@ __host__ uint64_t scratch_cuda_integer_overflowing_sub_kb(
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_overflowing_sub(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *output,
|
||||
CudaRadixCiphertextFFI *input_left,
|
||||
@@ -134,7 +134,7 @@ __host__ void host_integer_overflowing_sub(
|
||||
CudaRadixCiphertextFFI *overflow_block,
|
||||
const CudaRadixCiphertextFFI *input_borrow,
|
||||
int_borrow_prop_memory<uint64_t> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t compute_overflow, uint32_t uses_input_borrow) {
|
||||
KSTorus *const *ksks, uint32_t compute_overflow, uint32_t uses_input_borrow) {
|
||||
PUSH_RANGE("overflowing sub")
|
||||
if (output->num_radix_blocks != input_left->num_radix_blocks ||
|
||||
output->num_radix_blocks != input_right->num_radix_blocks)
|
||||
|
||||
@@ -10,7 +10,7 @@ void cuda_scalar_bitop_integer_radix_ciphertext_kb_64(
|
||||
CudaStreams(streams), lwe_array_out, lwe_array_input,
|
||||
static_cast<const uint64_t *>(clear_blocks),
|
||||
static_cast<const uint64_t *>(h_clear_blocks), num_clear_blocks,
|
||||
(int_bitop_buffer<uint64_t> *)mem_ptr, bsks, (uint64_t **)(ksks));
|
||||
(int_bitop_buffer<uint64_t> *)mem_ptr, bsks, (uint32_t **)(ksks));
|
||||
}
|
||||
|
||||
void update_degrees_after_scalar_bitand(uint64_t *output_degrees,
|
||||
|
||||
@@ -4,12 +4,12 @@
|
||||
#include "integer/bitwise_ops.cuh"
|
||||
#include <omp.h>
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_radix_scalar_bitop_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *output,
|
||||
CudaRadixCiphertextFFI const *input, Torus const *clear_blocks,
|
||||
Torus const *h_clear_blocks, uint32_t num_clear_blocks,
|
||||
int_bitop_buffer<Torus> *mem_ptr, void *const *bsks, Torus *const *ksks) {
|
||||
int_bitop_buffer<Torus> *mem_ptr, void *const *bsks, KSTorus *const *ksks) {
|
||||
|
||||
if (output->num_radix_blocks != input->num_radix_blocks)
|
||||
PANIC("Cuda error: input and output num radix blocks must be equal")
|
||||
|
||||
@@ -49,7 +49,7 @@ void cuda_scalar_comparison_integer_radix_ciphertext_kb_64(
|
||||
host_integer_radix_scalar_equality_check_kb<uint64_t>(
|
||||
CudaStreams(streams), lwe_array_out, lwe_array_in,
|
||||
static_cast<const uint64_t *>(scalar_blocks), buffer, bsks,
|
||||
(uint64_t **)(ksks), num_radix_blocks, num_scalar_blocks);
|
||||
(uint32_t **)(ksks), num_radix_blocks, num_scalar_blocks);
|
||||
break;
|
||||
case GT:
|
||||
case GE:
|
||||
@@ -62,7 +62,7 @@ void cuda_scalar_comparison_integer_radix_ciphertext_kb_64(
|
||||
CudaStreams(streams), lwe_array_out, lwe_array_in,
|
||||
static_cast<const uint64_t *>(scalar_blocks),
|
||||
static_cast<const uint64_t *>(h_scalar_blocks), buffer,
|
||||
buffer->diff_buffer->operator_f, bsks, (uint64_t **)(ksks),
|
||||
buffer->diff_buffer->operator_f, bsks, (uint32_t **)(ksks),
|
||||
num_radix_blocks, num_scalar_blocks);
|
||||
break;
|
||||
case MAX:
|
||||
@@ -74,7 +74,7 @@ void cuda_scalar_comparison_integer_radix_ciphertext_kb_64(
|
||||
CudaStreams(streams), lwe_array_out, lwe_array_in,
|
||||
static_cast<const uint64_t *>(scalar_blocks),
|
||||
static_cast<const uint64_t *>(h_scalar_blocks), buffer, bsks,
|
||||
(uint64_t **)(ksks), num_radix_blocks, num_scalar_blocks);
|
||||
(uint32_t **)(ksks), num_radix_blocks, num_scalar_blocks);
|
||||
break;
|
||||
default:
|
||||
PANIC("Cuda error: integer operation not supported")
|
||||
|
||||
@@ -24,12 +24,12 @@ Torus is_x_less_than_y_given_input_borrow(Torus last_x_block,
|
||||
return output_sign_bit ^ overflow_flag;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void scalar_compare_radix_blocks_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI *lwe_array_in, Torus *scalar_blocks,
|
||||
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t num_radix_blocks) {
|
||||
KSTorus *const *ksks, uint32_t num_radix_blocks) {
|
||||
|
||||
if (num_radix_blocks == 0)
|
||||
return;
|
||||
@@ -82,13 +82,13 @@ __host__ void scalar_compare_radix_blocks_kb(
|
||||
streams, lwe_array_out, message_modulus, carry_modulus);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void integer_radix_unsigned_scalar_difference_check_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_in, Torus const *scalar_blocks,
|
||||
Torus const *h_scalar_blocks, int_comparison_buffer<Torus> *mem_ptr,
|
||||
std::function<Torus(Torus)> sign_handler_f, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t num_radix_blocks, uint32_t num_scalar_blocks) {
|
||||
KSTorus *const *ksks, uint32_t num_radix_blocks, uint32_t num_scalar_blocks) {
|
||||
if (lwe_array_out->lwe_dimension != lwe_array_in->lwe_dimension)
|
||||
PANIC("Cuda error: input lwe dimensions must be the same")
|
||||
if (lwe_array_in->num_radix_blocks < num_radix_blocks)
|
||||
@@ -320,13 +320,13 @@ __host__ void integer_radix_unsigned_scalar_difference_check_kb(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void integer_radix_signed_scalar_difference_check_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_in, Torus const *scalar_blocks,
|
||||
Torus const *h_scalar_blocks, int_comparison_buffer<Torus> *mem_ptr,
|
||||
std::function<Torus(Torus)> sign_handler_f, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t num_radix_blocks, uint32_t num_scalar_blocks) {
|
||||
KSTorus *const *ksks, uint32_t num_radix_blocks, uint32_t num_scalar_blocks) {
|
||||
|
||||
if (lwe_array_out->lwe_dimension != lwe_array_in->lwe_dimension)
|
||||
PANIC("Cuda error: input lwe dimensions must be the same")
|
||||
@@ -638,13 +638,13 @@ __host__ void integer_radix_signed_scalar_difference_check_kb(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_radix_scalar_difference_check_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_in, Torus const *scalar_blocks,
|
||||
Torus const *h_scalar_blocks, int_comparison_buffer<Torus> *mem_ptr,
|
||||
std::function<Torus(Torus)> sign_handler_f, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t num_radix_blocks, uint32_t num_scalar_blocks) {
|
||||
KSTorus *const *ksks, uint32_t num_radix_blocks, uint32_t num_scalar_blocks) {
|
||||
|
||||
if (lwe_array_out->lwe_dimension != lwe_array_in->lwe_dimension)
|
||||
PANIC("Cuda error: input lwe dimensions must be the same")
|
||||
@@ -666,12 +666,12 @@ __host__ void host_integer_radix_scalar_difference_check_kb(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_radix_scalar_maxmin_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_in, Torus const *scalar_blocks,
|
||||
Torus const *h_scalar_blocks, int_comparison_buffer<Torus> *mem_ptr,
|
||||
void *const *bsks, Torus *const *ksks, uint32_t num_radix_blocks,
|
||||
void *const *bsks, KSTorus *const *ksks, uint32_t num_radix_blocks,
|
||||
uint32_t num_scalar_blocks) {
|
||||
|
||||
if (lwe_array_out->lwe_dimension != lwe_array_in->lwe_dimension)
|
||||
@@ -709,12 +709,12 @@ __host__ void host_integer_radix_scalar_maxmin_kb(
|
||||
lwe_array_right, mem_ptr->cmux_buffer, bsks, ksks);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_radix_scalar_equality_check_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_in, Torus const *scalar_blocks,
|
||||
int_comparison_buffer<Torus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t num_radix_blocks, uint32_t num_scalar_blocks) {
|
||||
KSTorus *const *ksks, uint32_t num_radix_blocks, uint32_t num_scalar_blocks) {
|
||||
|
||||
if (lwe_array_out->lwe_dimension != lwe_array_in->lwe_dimension)
|
||||
PANIC("Cuda error: input and output lwe dimensions must be the same")
|
||||
|
||||
@@ -28,7 +28,7 @@ void cuda_integer_unsigned_scalar_div_radix_kb_64(
|
||||
|
||||
host_integer_unsigned_scalar_div_radix<uint64_t>(
|
||||
CudaStreams(streams), numerator_ct,
|
||||
(int_unsigned_scalar_div_mem<uint64_t> *)mem_ptr, bsks, (uint64_t **)ksks,
|
||||
(int_unsigned_scalar_div_mem<uint64_t> *)mem_ptr, bsks, (uint32_t **)ksks,
|
||||
scalar_divisor_ffi);
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ void cuda_integer_signed_scalar_div_radix_kb_64(
|
||||
|
||||
host_integer_signed_scalar_div_radix_kb<uint64_t>(
|
||||
CudaStreams(streams), numerator_ct,
|
||||
(int_signed_scalar_div_mem<uint64_t> *)mem_ptr, bsks, (uint64_t **)ksks,
|
||||
(int_signed_scalar_div_mem<uint64_t> *)mem_ptr, bsks, (uint32_t **)ksks,
|
||||
scalar_divisor_ffi, numerator_bits);
|
||||
}
|
||||
|
||||
@@ -122,7 +122,7 @@ void cuda_integer_unsigned_scalar_div_rem_radix_kb_64(
|
||||
host_integer_unsigned_scalar_div_rem_radix<uint64_t>(
|
||||
CudaStreams(streams), quotient_ct, remainder_ct,
|
||||
(int_unsigned_scalar_div_rem_buffer<uint64_t> *)mem_ptr, bsks,
|
||||
(uint64_t **)ksks, scalar_divisor_ffi, divisor_has_at_least_one_set,
|
||||
(uint32_t **)ksks, scalar_divisor_ffi, divisor_has_at_least_one_set,
|
||||
decomposed_divisor, num_scalars_divisor, (uint64_t *)clear_blocks,
|
||||
(uint64_t *)h_clear_blocks, num_clear_blocks);
|
||||
}
|
||||
@@ -172,7 +172,7 @@ void cuda_integer_signed_scalar_div_rem_radix_kb_64(
|
||||
host_integer_signed_scalar_div_rem_radix<uint64_t>(
|
||||
CudaStreams(streams), quotient_ct, remainder_ct,
|
||||
(int_signed_scalar_div_rem_buffer<uint64_t> *)mem_ptr, bsks,
|
||||
(uint64_t **)ksks, scalar_divisor_ffi, divisor_has_at_least_one_set,
|
||||
(uint32_t **)ksks, scalar_divisor_ffi, divisor_has_at_least_one_set,
|
||||
decomposed_divisor, num_scalars_divisor, numerator_bits);
|
||||
}
|
||||
|
||||
|
||||
@@ -23,11 +23,11 @@ __host__ uint64_t scratch_integer_unsigned_scalar_div_radix(
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_unsigned_scalar_div_radix(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *numerator_ct,
|
||||
int_unsigned_scalar_div_mem<Torus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks, const CudaScalarDivisorFFI *scalar_divisor_ffi) {
|
||||
KSTorus *const *ksks, const CudaScalarDivisorFFI *scalar_divisor_ffi) {
|
||||
|
||||
if (scalar_divisor_ffi->is_abs_divisor_one) {
|
||||
return;
|
||||
@@ -117,11 +117,11 @@ __host__ uint64_t scratch_integer_signed_scalar_div_radix_kb(
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_signed_scalar_div_radix_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *numerator_ct,
|
||||
int_signed_scalar_div_mem<Torus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks, const CudaScalarDivisorFFI *scalar_divisor_ffi,
|
||||
KSTorus *const *ksks, const CudaScalarDivisorFFI *scalar_divisor_ffi,
|
||||
uint32_t numerator_bits) {
|
||||
|
||||
if (scalar_divisor_ffi->is_abs_divisor_one) {
|
||||
@@ -246,12 +246,12 @@ __host__ uint64_t scratch_integer_unsigned_scalar_div_rem_radix(
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_unsigned_scalar_div_rem_radix(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *quotient_ct,
|
||||
CudaRadixCiphertextFFI *remainder_ct,
|
||||
int_unsigned_scalar_div_rem_buffer<Torus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks, const CudaScalarDivisorFFI *scalar_divisor_ffi,
|
||||
KSTorus *const *ksks, const CudaScalarDivisorFFI *scalar_divisor_ffi,
|
||||
uint64_t const *divisor_has_at_least_one_set,
|
||||
uint64_t const *decomposed_divisor, uint32_t const num_scalars_divisor,
|
||||
Torus const *clear_blocks, Torus const *h_clear_blocks,
|
||||
@@ -313,12 +313,12 @@ __host__ uint64_t scratch_integer_signed_scalar_div_rem_radix(
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_signed_scalar_div_rem_radix(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *quotient_ct,
|
||||
CudaRadixCiphertextFFI *remainder_ct,
|
||||
int_signed_scalar_div_rem_buffer<Torus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks, const CudaScalarDivisorFFI *scalar_divisor_ffi,
|
||||
KSTorus *const *ksks, const CudaScalarDivisorFFI *scalar_divisor_ffi,
|
||||
uint64_t const *divisor_has_at_least_one_set,
|
||||
uint64_t const *decomposed_divisor, uint32_t const num_scalars_divisor,
|
||||
uint32_t numerator_bits) {
|
||||
|
||||
@@ -28,7 +28,7 @@ void cuda_scalar_multiplication_integer_radix_ciphertext_64_inplace(
|
||||
host_integer_scalar_mul_radix<uint64_t>(
|
||||
CudaStreams(streams), lwe_array, decomposed_scalar, has_at_least_one_set,
|
||||
reinterpret_cast<int_scalar_mul_buffer<uint64_t> *>(mem), bsks,
|
||||
(uint64_t **)(ksks), message_modulus, num_scalars);
|
||||
(uint32_t **)(ksks), message_modulus, num_scalars);
|
||||
}
|
||||
|
||||
void cleanup_cuda_integer_radix_scalar_mul(CudaStreamsFFI streams,
|
||||
|
||||
@@ -41,11 +41,11 @@ __host__ uint64_t scratch_cuda_integer_radix_scalar_mul_kb(
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename KSTorus>
|
||||
__host__ void host_integer_scalar_mul_radix(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array,
|
||||
T const *decomposed_scalar, T const *has_at_least_one_set,
|
||||
int_scalar_mul_buffer<T> *mem, void *const *bsks, T *const *ksks,
|
||||
int_scalar_mul_buffer<T> *mem, void *const *bsks, KSTorus *const *ksks,
|
||||
uint32_t message_modulus, uint32_t num_scalars) {
|
||||
|
||||
auto num_radix_blocks = lwe_array->num_radix_blocks;
|
||||
@@ -164,10 +164,10 @@ __host__ void host_integer_small_scalar_mul_radix(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_radix_scalar_mul_high_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *ct,
|
||||
int_scalar_mul_high_buffer<Torus> *mem_ptr, Torus *const *ksks,
|
||||
int_scalar_mul_high_buffer<Torus> *mem_ptr, KSTorus *const *ksks,
|
||||
void *const *bsks, const CudaScalarDivisorFFI *scalar_divisor_ffi) {
|
||||
|
||||
if (scalar_divisor_ffi->is_chosen_multiplier_zero) {
|
||||
@@ -187,7 +187,7 @@ __host__ void host_integer_radix_scalar_mul_high_kb(
|
||||
if (scalar_divisor_ffi->is_chosen_multiplier_pow2) {
|
||||
host_integer_radix_logical_scalar_shift_kb_inplace<Torus>(
|
||||
streams, tmp_ffi, scalar_divisor_ffi->ilog2_chosen_multiplier,
|
||||
mem_ptr->logical_scalar_shift_mem, bsks, (uint64_t **)ksks,
|
||||
mem_ptr->logical_scalar_shift_mem, bsks, (uint32_t **)ksks,
|
||||
tmp_ffi->num_radix_blocks);
|
||||
|
||||
} else {
|
||||
@@ -195,7 +195,7 @@ __host__ void host_integer_radix_scalar_mul_high_kb(
|
||||
host_integer_scalar_mul_radix<Torus>(
|
||||
streams, tmp_ffi, scalar_divisor_ffi->decomposed_chosen_multiplier,
|
||||
scalar_divisor_ffi->chosen_multiplier_has_at_least_one_set,
|
||||
mem_ptr->scalar_mul_mem, bsks, (uint64_t **)ksks,
|
||||
mem_ptr->scalar_mul_mem, bsks, (uint32_t **)ksks,
|
||||
mem_ptr->params.message_modulus, scalar_divisor_ffi->num_scalars);
|
||||
}
|
||||
}
|
||||
@@ -203,10 +203,10 @@ __host__ void host_integer_radix_scalar_mul_high_kb(
|
||||
host_trim_radix_blocks_lsb<Torus>(ct, tmp_ffi, streams);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_radix_signed_scalar_mul_high_kb(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *ct,
|
||||
int_signed_scalar_mul_high_buffer<Torus> *mem_ptr, Torus *const *ksks,
|
||||
int_signed_scalar_mul_high_buffer<Torus> *mem_ptr, KSTorus *const *ksks,
|
||||
const CudaScalarDivisorFFI *scalar_divisor_ffi, void *const *bsks) {
|
||||
|
||||
if (scalar_divisor_ffi->is_chosen_multiplier_zero) {
|
||||
@@ -219,7 +219,7 @@ __host__ void host_integer_radix_signed_scalar_mul_high_kb(
|
||||
|
||||
host_extend_radix_with_sign_msb<Torus>(
|
||||
streams, tmp_ffi, ct, mem_ptr->extend_radix_mem, ct->num_radix_blocks,
|
||||
bsks, (uint64_t **)ksks);
|
||||
bsks, (uint32_t **)ksks);
|
||||
|
||||
if (scalar_divisor_ffi->active_bits != (uint32_t)0 &&
|
||||
!scalar_divisor_ffi->is_abs_chosen_multiplier_one &&
|
||||
@@ -228,13 +228,13 @@ __host__ void host_integer_radix_signed_scalar_mul_high_kb(
|
||||
if (scalar_divisor_ffi->is_chosen_multiplier_pow2) {
|
||||
host_integer_radix_logical_scalar_shift_kb_inplace<Torus>(
|
||||
streams, tmp_ffi, scalar_divisor_ffi->ilog2_chosen_multiplier,
|
||||
mem_ptr->logical_scalar_shift_mem, bsks, (uint64_t **)ksks,
|
||||
mem_ptr->logical_scalar_shift_mem, bsks, (uint32_t **)ksks,
|
||||
tmp_ffi->num_radix_blocks);
|
||||
} else {
|
||||
host_integer_scalar_mul_radix<Torus>(
|
||||
streams, tmp_ffi, scalar_divisor_ffi->decomposed_chosen_multiplier,
|
||||
scalar_divisor_ffi->chosen_multiplier_has_at_least_one_set,
|
||||
mem_ptr->scalar_mul_mem, bsks, (uint64_t **)ksks,
|
||||
mem_ptr->scalar_mul_mem, bsks, (uint32_t **)ksks,
|
||||
mem_ptr->params.message_modulus, scalar_divisor_ffi->num_scalars);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ void cuda_integer_radix_scalar_rotate_kb_64_inplace(
|
||||
host_integer_radix_scalar_rotate_kb_inplace<uint64_t>(
|
||||
CudaStreams(streams), lwe_array, n,
|
||||
(int_logical_scalar_shift_buffer<uint64_t> *)mem_ptr, bsks,
|
||||
(uint64_t **)(ksks));
|
||||
(uint32_t **)(ksks));
|
||||
}
|
||||
|
||||
void cleanup_cuda_integer_radix_scalar_rotate(CudaStreamsFFI streams,
|
||||
|
||||
@@ -24,11 +24,11 @@ __host__ uint64_t scratch_cuda_integer_radix_scalar_rotate_kb(
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_radix_scalar_rotate_kb_inplace(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array, uint32_t n,
|
||||
int_logical_scalar_shift_buffer<Torus> *mem, void *const *bsks,
|
||||
Torus *const *ksks) {
|
||||
KSTorus *const *ksks) {
|
||||
|
||||
auto num_blocks = lwe_array->num_radix_blocks;
|
||||
auto params = mem->params;
|
||||
|
||||
@@ -31,7 +31,7 @@ void cuda_integer_radix_logical_scalar_shift_kb_64_inplace(
|
||||
host_integer_radix_logical_scalar_shift_kb_inplace<uint64_t>(
|
||||
CudaStreams(streams), lwe_array, shift,
|
||||
(int_logical_scalar_shift_buffer<uint64_t> *)mem_ptr, bsks,
|
||||
(uint64_t **)(ksks), lwe_array->num_radix_blocks);
|
||||
(uint32_t **)(ksks), lwe_array->num_radix_blocks);
|
||||
}
|
||||
|
||||
uint64_t scratch_cuda_integer_radix_arithmetic_scalar_shift_kb_64(
|
||||
@@ -68,7 +68,7 @@ void cuda_integer_radix_arithmetic_scalar_shift_kb_64_inplace(
|
||||
host_integer_radix_arithmetic_scalar_shift_kb_inplace<uint64_t>(
|
||||
CudaStreams(streams), lwe_array, shift,
|
||||
(int_arithmetic_scalar_shift_buffer<uint64_t> *)mem_ptr, bsks,
|
||||
(uint64_t **)(ksks));
|
||||
(uint32_t **)(ksks));
|
||||
}
|
||||
|
||||
void cleanup_cuda_integer_radix_logical_scalar_shift(CudaStreamsFFI streams,
|
||||
|
||||
@@ -24,11 +24,11 @@ __host__ uint64_t scratch_cuda_integer_radix_logical_scalar_shift_kb(
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_radix_logical_scalar_shift_kb_inplace(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array, uint32_t shift,
|
||||
int_logical_scalar_shift_buffer<Torus> *mem, void *const *bsks,
|
||||
Torus *const *ksks, uint32_t num_blocks) {
|
||||
KSTorus *const *ksks, uint32_t num_blocks) {
|
||||
|
||||
if (lwe_array->num_radix_blocks < num_blocks)
|
||||
PANIC("Cuda error: input does not have enough blocks")
|
||||
@@ -128,11 +128,11 @@ __host__ uint64_t scratch_cuda_integer_radix_arithmetic_scalar_shift_kb(
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array, uint32_t shift,
|
||||
int_arithmetic_scalar_shift_buffer<Torus> *mem, void *const *bsks,
|
||||
Torus *const *ksks) {
|
||||
KSTorus *const *ksks) {
|
||||
|
||||
auto num_blocks = lwe_array->num_radix_blocks;
|
||||
auto params = mem->params;
|
||||
|
||||
@@ -27,7 +27,7 @@ void cuda_integer_radix_shift_and_rotate_kb_64_inplace(
|
||||
host_integer_radix_shift_and_rotate_kb_inplace<uint64_t>(
|
||||
CudaStreams(streams), lwe_array, lwe_shift,
|
||||
(int_shift_and_rotate_buffer<uint64_t> *)mem_ptr, bsks,
|
||||
(uint64_t **)(ksks));
|
||||
(uint32_t **)(ksks));
|
||||
}
|
||||
|
||||
void cleanup_cuda_integer_radix_shift_and_rotate(CudaStreamsFFI streams,
|
||||
|
||||
@@ -24,12 +24,12 @@ __host__ uint64_t scratch_cuda_integer_radix_shift_and_rotate_kb(
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
__host__ void host_integer_radix_shift_and_rotate_kb_inplace(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array,
|
||||
CudaRadixCiphertextFFI const *lwe_shift,
|
||||
int_shift_and_rotate_buffer<Torus> *mem, void *const *bsks,
|
||||
Torus *const *ksks) {
|
||||
KSTorus *const *ksks) {
|
||||
cuda_set_device(streams.gpu_index(0));
|
||||
|
||||
if (lwe_array->num_radix_blocks != lwe_shift->num_radix_blocks)
|
||||
|
||||
@@ -27,7 +27,7 @@ void cuda_sub_and_propagate_single_carry_kb_64_inplace(
|
||||
PUSH_RANGE("sub")
|
||||
host_sub_and_propagate_single_carry<uint64_t>(
|
||||
CudaStreams(streams), lhs_array, rhs_array, carry_out, carry_in,
|
||||
(int_sub_and_propagate<uint64_t> *)mem_ptr, bsks, (uint64_t **)(ksks),
|
||||
(int_sub_and_propagate<uint64_t> *)mem_ptr, bsks, (uint32_t **)(ksks),
|
||||
requested_flag, uses_carry);
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
@@ -27,12 +27,12 @@ uint64_t scratch_cuda_sub_and_propagate_single_carry(
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
template <typename Torus, typename KSTorus>
|
||||
void host_sub_and_propagate_single_carry(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *lhs_array,
|
||||
const CudaRadixCiphertextFFI *rhs_array, CudaRadixCiphertextFFI *carry_out,
|
||||
const CudaRadixCiphertextFFI *input_carries,
|
||||
int_sub_and_propagate<Torus> *mem, void *const *bsks, Torus *const *ksks,
|
||||
int_sub_and_propagate<Torus> *mem, void *const *bsks, KSTorus *const *ksks,
|
||||
uint32_t requested_flag, uint32_t uses_carry) {
|
||||
|
||||
host_integer_radix_negation<Torus>(
|
||||
|
||||
@@ -9,6 +9,9 @@ use crate::core_crypto::prelude::{
|
||||
lwe_keyswitch_key_input_key_element_encrypted_size, LweKeyswitchKeyOwned, LweSize,
|
||||
UnsignedInteger,
|
||||
};
|
||||
use crate::prelude::{CastFrom, CastInto};
|
||||
use itertools::Itertools;
|
||||
use std::any::TypeId;
|
||||
|
||||
#[derive(Clone)]
|
||||
#[allow(dead_code)]
|
||||
@@ -22,12 +25,18 @@ pub struct CudaLweKeyswitchKey<T: UnsignedInteger> {
|
||||
}
|
||||
|
||||
impl<T: UnsignedInteger> CudaLweKeyswitchKey<T> {
|
||||
pub fn from_lwe_keyswitch_key(h_ksk: &LweKeyswitchKeyOwned<T>, streams: &CudaStreams) -> Self {
|
||||
pub fn from_lwe_keyswitch_key<O: UnsignedInteger>(
|
||||
h_ksk: &LweKeyswitchKeyOwned<O>,
|
||||
streams: &CudaStreams,
|
||||
) -> Self
|
||||
where
|
||||
O: CastInto<T>,
|
||||
{
|
||||
let decomp_base_log = h_ksk.decomposition_base_log();
|
||||
let decomp_level_count = h_ksk.decomposition_level_count();
|
||||
let input_lwe_size = h_ksk.input_key_lwe_dimension().to_lwe_size();
|
||||
let output_lwe_size = h_ksk.output_key_lwe_dimension().to_lwe_size();
|
||||
let ciphertext_modulus = h_ksk.ciphertext_modulus();
|
||||
let ciphertext_modulus = CiphertextModulus::<T>::new_native(); //h_ksk.ciphertext_modulus().try_to().unwrap();
|
||||
|
||||
// Allocate memory
|
||||
let mut d_vec = CudaVec::<T>::new_multi_gpu(
|
||||
@@ -39,8 +48,26 @@ impl<T: UnsignedInteger> CudaLweKeyswitchKey<T> {
|
||||
streams,
|
||||
);
|
||||
|
||||
unsafe {
|
||||
convert_lwe_keyswitch_key_async(streams, &mut d_vec, h_ksk.as_ref());
|
||||
if TypeId::of::<T>() == TypeId::of::<O>() {
|
||||
panic!("Forced KSK to u32 not working!");
|
||||
unsafe {
|
||||
let casted = unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
h_ksk.as_ref().as_ptr() as *const T,
|
||||
h_ksk.as_ref().len(),
|
||||
)
|
||||
};
|
||||
convert_lwe_keyswitch_key_async(streams, &mut d_vec, casted);
|
||||
}
|
||||
} else {
|
||||
let dcast: Vec<T> = h_ksk
|
||||
.as_ref()
|
||||
.iter()
|
||||
.map(|v| (*v).cast_into())
|
||||
.collect_vec();
|
||||
unsafe {
|
||||
d_vec.copy_from_cpu_multi_gpu_async(dcast.as_slice(), streams);
|
||||
}
|
||||
}
|
||||
|
||||
streams.synchronize();
|
||||
|
||||
@@ -377,7 +377,7 @@ pub unsafe fn unchecked_scalar_mul_integer_radix_kb_async<T: UnsignedInteger, B:
|
||||
decomposed_scalar: &[T],
|
||||
has_at_least_one_set: &[T],
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<u64>,
|
||||
keyswitch_key: &CudaVec<u32>,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
@@ -1523,7 +1523,7 @@ pub unsafe fn unchecked_scalar_bitop_integer_radix_kb_assign_async<
|
||||
clear_blocks: &CudaVec<T>,
|
||||
h_clear_blocks: &[T],
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
keyswitch_key: &CudaVec<u32>,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
@@ -1888,7 +1888,7 @@ pub unsafe fn unchecked_scalar_comparison_integer_radix_kb_async<T: UnsignedInte
|
||||
scalar_blocks: &CudaVec<T>,
|
||||
h_scalar_blocks: &[T],
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
keyswitch_key: &CudaVec<u32>,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
@@ -5612,7 +5612,7 @@ pub unsafe fn apply_univariate_lut_kb_async<T: UnsignedInteger, B: Numeric>(
|
||||
input_lut: &[T],
|
||||
lut_degree: u64,
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
keyswitch_key: &CudaVec<u32>,
|
||||
lwe_dimension: LweDimension,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
@@ -5718,7 +5718,7 @@ pub unsafe fn apply_many_univariate_lut_kb_async<T: UnsignedInteger, B: Numeric>
|
||||
input_lut: &[T],
|
||||
lut_degree: u64,
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
keyswitch_key: &CudaVec<u32>,
|
||||
lwe_dimension: LweDimension,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
@@ -5829,7 +5829,7 @@ pub unsafe fn apply_bivariate_lut_kb_async<T: UnsignedInteger, B: Numeric>(
|
||||
input_lut: &[T],
|
||||
lut_degree: u64,
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
keyswitch_key: &CudaVec<u32>,
|
||||
lwe_dimension: LweDimension,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
@@ -6422,7 +6422,7 @@ pub unsafe fn compute_prefix_sum_hillis_steele_async<T: UnsignedInteger, B: Nume
|
||||
input_lut: &[T],
|
||||
lut_degree: u64,
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
keyswitch_key: &CudaVec<u32>,
|
||||
lwe_dimension: LweDimension,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
@@ -7186,7 +7186,7 @@ pub unsafe fn noise_squashing_async<T: UnsignedInteger, B: Numeric>(
|
||||
output_noise_levels: &mut Vec<u64>,
|
||||
input: &CudaSlice<u64>,
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<u64>,
|
||||
keyswitch_key: &CudaVec<u32>,
|
||||
lwe_dimension: LweDimension,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
@@ -7302,7 +7302,7 @@ pub unsafe fn expand_async<T: UnsignedInteger, B: Numeric>(
|
||||
lwe_array_out: &mut CudaLweCiphertextList<T>,
|
||||
lwe_flattened_compact_array_in: &CudaVec<T>,
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
computing_ks_key: &CudaVec<T>,
|
||||
computing_ks_key: &CudaVec<u32>,
|
||||
casting_key: &CudaVec<T>,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
|
||||
@@ -42,7 +42,7 @@ impl<Scalar: UnsignedInteger> CudaBootstrappingKey<Scalar> {
|
||||
/// sends it to the server so it can compute homomorphic circuits.
|
||||
// #[derive(PartialEq, Serialize, Deserialize)]
|
||||
pub struct CudaServerKey {
|
||||
pub key_switching_key: CudaLweKeyswitchKey<u64>,
|
||||
pub key_switching_key: CudaLweKeyswitchKey<u32>,
|
||||
pub bootstrapping_key: CudaBootstrappingKey<u64>,
|
||||
// Size of the message buffer
|
||||
pub message_modulus: MessageModulus,
|
||||
|
||||
Reference in New Issue
Block a user