mirror of https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-11 07:38:08 -05:00

Compare commits: tm/flip ... al/ci_fixe

8 Commits

| Author | SHA1 | Date |
|---|---|---|
| | a01814e991 | |
| | b5d50cec5a | |
| | c619eb479e | |
| | 4e8bdc4380 | |
| | 1deaaf5249 | |
| | abd2fe1f4e | |
| | 47d671b043 | |
| | f700016776 | |
.github/workflows/gpu_fast_h100_tests.yml (vendored, 4 changes)
@@ -146,8 +146,8 @@ jobs:
       - name: Run core crypto and internal CUDA backend tests
         run: |
-          BIG_TESTS_INSTANCE=TRUE make test_core_crypto_gpu
-          BIG_TESTS_INSTANCE=TRUE make test_integer_compression_gpu
+          BIG_TESTS_INSTANCE=FALSE make test_core_crypto_gpu
+          BIG_TESTS_INSTANCE=FALSE make test_integer_compression_gpu
           BIG_TESTS_INSTANCE=TRUE make test_cuda_backend

       - name: Run user docs tests
@@ -386,8 +386,8 @@ void scratch_cuda_integer_compute_prefix_sum_hillis_steele_64(

 void cuda_integer_compute_prefix_sum_hillis_steele_64(
     void **streams, uint32_t *gpu_indexes, uint32_t gpu_count,
-    void *output_radix_lwe, void *input_radix_lwe, int8_t *mem_ptr, void **ksks,
-    void **bsks, uint32_t num_blocks, uint32_t shift);
+    void *output_radix_lwe, void *generates_or_propagates, int8_t *mem_ptr,
+    void **ksks, void **bsks, uint32_t num_blocks, uint32_t shift);

 void cleanup_cuda_integer_compute_prefix_sum_hillis_steele_64(
     void **streams, uint32_t *gpu_indexes, uint32_t gpu_count,
@@ -1357,6 +1357,7 @@ template <typename Torus> struct int_overflowing_sub_memory {

 template <typename Torus> struct int_sum_ciphertexts_vec_memory {
   Torus *new_blocks;
+  Torus *new_blocks_copy;
   Torus *old_blocks;
   Torus *small_lwe_vector;
   int_radix_params params;
@@ -1384,6 +1385,9 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
     new_blocks = (Torus *)cuda_malloc_async(
         max_pbs_count * (params.big_lwe_dimension + 1) * sizeof(Torus),
         streams[0], gpu_indexes[0]);
+    new_blocks_copy = (Torus *)cuda_malloc_async(
+        max_pbs_count * (params.big_lwe_dimension + 1) * sizeof(Torus),
+        streams[0], gpu_indexes[0]);
     old_blocks = (Torus *)cuda_malloc_async(
         max_pbs_count * (params.big_lwe_dimension + 1) * sizeof(Torus),
         streams[0], gpu_indexes[0]);
@@ -1415,6 +1419,9 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
     this->new_blocks = new_blocks;
     this->old_blocks = old_blocks;
     this->small_lwe_vector = small_lwe_vector;
+    new_blocks_copy = (Torus *)cuda_malloc_async(
+        max_pbs_count * (params.big_lwe_dimension + 1) * sizeof(Torus),
+        streams[0], gpu_indexes[0]);

     d_smart_copy_in = (int32_t *)cuda_malloc_async(
         max_pbs_count * sizeof(int32_t), streams[0], gpu_indexes[0]);
@@ -1433,8 +1440,8 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
       cuda_drop_async(small_lwe_vector, streams[0], gpu_indexes[0]);
     }
+    cuda_drop_async(new_blocks_copy, streams[0], gpu_indexes[0]);
     scp_mem->release(streams, gpu_indexes, gpu_count);

     delete scp_mem;
   }
 };
@@ -81,14 +81,6 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector_64(

 void cleanup_cuda_programmable_bootstrap(void *stream, uint32_t gpu_index,
                                          int8_t **pbs_buffer);

-uint64_t get_buffer_size_programmable_bootstrap_amortized_64(
-    uint32_t glwe_dimension, uint32_t polynomial_size,
-    uint32_t input_lwe_ciphertext_count);
-
-uint64_t get_buffer_size_programmable_bootstrap_64(
-    uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
-    uint32_t input_lwe_ciphertext_count);
 }

 template <typename Torus>
@@ -195,15 +195,15 @@ void scratch_cuda_integer_compute_prefix_sum_hillis_steele_64(

 void cuda_integer_compute_prefix_sum_hillis_steele_64(
     void **streams, uint32_t *gpu_indexes, uint32_t gpu_count,
-    void *output_radix_lwe, void *input_radix_lwe, int8_t *mem_ptr, void **ksks,
-    void **bsks, uint32_t num_blocks, uint32_t shift) {
+    void *output_radix_lwe, void *generates_or_propagates, int8_t *mem_ptr,
+    void **ksks, void **bsks, uint32_t num_blocks, uint32_t shift) {

   int_radix_params params = ((int_radix_lut<uint64_t> *)mem_ptr)->params;

   host_compute_prefix_sum_hillis_steele<uint64_t>(
       (cudaStream_t *)(streams), gpu_indexes, gpu_count,
       static_cast<uint64_t *>(output_radix_lwe),
-      static_cast<uint64_t *>(input_radix_lwe), params,
+      static_cast<uint64_t *>(generates_or_propagates), params,
       (int_radix_lut<uint64_t> *)mem_ptr, bsks, (uint64_t **)(ksks),
       num_blocks);
 }
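For orientation (editor's sketch, not code from this PR): the kernel driven above computes a Hillis-Steele inclusive scan over the per-block `generates_or_propagates` flags. On clear data the scan looks like this; `combine` stands in for the carry-resolution operator that the GPU applies via programmable bootstraps:

```rust
// Clear-data sketch of a Hillis-Steele inclusive scan. `combine` is a stand-in
// for the carry-resolution operator; the function is illustrative only and is
// not part of tfhe-rs.
fn hillis_steele_scan<T: Copy>(input: &[T], combine: impl Fn(T, T) -> T) -> Vec<T> {
    let mut cur = input.to_vec();
    let mut step = 1;
    while step < cur.len() {
        let prev = cur.clone(); // double-buffer: read `prev`, write `cur`
        for i in step..cur.len() {
            // each element absorbs the partial result `step` slots behind it
            cur[i] = combine(prev[i - step], prev[i]);
        }
        step *= 2;
    }
    cur
}
```

With `combine = |a, b| a + b` this is an ordinary prefix sum in O(log n) doubling steps. Note that each round reads the previous buffer in full, which is why the encrypted version keeps its input and output in separate device buffers.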
@@ -241,7 +241,8 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
         (cudaStream_t *)(streams), gpu_indexes, gpu_count,
         static_cast<uint64_t *>(radix_lwe_out),
         static_cast<uint64_t *>(radix_lwe_vec), terms_degree, bsks,
-        (uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec);
+        (uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec,
+        nullptr);
     break;
   case 1024:
     host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,

@@ -249,7 +250,8 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
         (cudaStream_t *)(streams), gpu_indexes, gpu_count,
         static_cast<uint64_t *>(radix_lwe_out),
         static_cast<uint64_t *>(radix_lwe_vec), terms_degree, bsks,
-        (uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec);
+        (uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec,
+        nullptr);
     break;
   case 2048:
     host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,

@@ -257,7 +259,8 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
         (cudaStream_t *)(streams), gpu_indexes, gpu_count,
         static_cast<uint64_t *>(radix_lwe_out),
         static_cast<uint64_t *>(radix_lwe_vec), terms_degree, bsks,
-        (uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec);
+        (uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec,
+        nullptr);
     break;
   case 4096:
     host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,

@@ -265,7 +268,8 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
         (cudaStream_t *)(streams), gpu_indexes, gpu_count,
         static_cast<uint64_t *>(radix_lwe_out),
         static_cast<uint64_t *>(radix_lwe_vec), terms_degree, bsks,
-        (uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec);
+        (uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec,
+        nullptr);
     break;
   case 8192:
     host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,

@@ -273,7 +277,8 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
         (cudaStream_t *)(streams), gpu_indexes, gpu_count,
         static_cast<uint64_t *>(radix_lwe_out),
         static_cast<uint64_t *>(radix_lwe_vec), terms_degree, bsks,
-        (uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec);
+        (uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec,
+        nullptr);
     break;
   case 16384:
     host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,

@@ -281,7 +286,8 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
         (cudaStream_t *)(streams), gpu_indexes, gpu_count,
         static_cast<uint64_t *>(radix_lwe_out),
         static_cast<uint64_t *>(radix_lwe_vec), terms_degree, bsks,
-        (uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec);
+        (uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec,
+        nullptr);
     break;
   default:
     PANIC("Cuda error (integer multiplication): unsupported polynomial size. "
@@ -186,9 +186,10 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
     Torus *radix_lwe_out, Torus *terms, int *terms_degree, void **bsks,
     uint64_t **ksks, int_sum_ciphertexts_vec_memory<uint64_t> *mem_ptr,
     uint32_t num_blocks_in_radix, uint32_t num_radix_in_vec,
-    int_radix_lut<Torus> *reused_lut = nullptr) {
+    int_radix_lut<Torus> *reused_lut) {

   auto new_blocks = mem_ptr->new_blocks;
+  auto new_blocks_copy = mem_ptr->new_blocks_copy;
   auto old_blocks = mem_ptr->old_blocks;
   auto small_lwe_vector = mem_ptr->small_lwe_vector;
@@ -205,12 +206,27 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
   auto small_lwe_dimension = mem_ptr->params.small_lwe_dimension;
   auto small_lwe_size = small_lwe_dimension + 1;

+  if (num_radix_in_vec == 0)
+    return;
+  if (num_radix_in_vec == 1) {
+    cuda_memcpy_async_gpu_to_gpu(radix_lwe_out, terms,
+                                 num_blocks_in_radix * big_lwe_size *
+                                     sizeof(Torus),
+                                 streams[0], gpu_indexes[0]);
+    return;
+  }
   if (old_blocks != terms) {
     cuda_memcpy_async_gpu_to_gpu(old_blocks, terms,
                                  num_blocks_in_radix * num_radix_in_vec *
                                      big_lwe_size * sizeof(Torus),
                                  streams[0], gpu_indexes[0]);
   }
+  if (num_radix_in_vec == 2) {
+    host_addition<Torus>(streams[0], gpu_indexes[0], radix_lwe_out, old_blocks,
+                         &old_blocks[num_blocks * big_lwe_size],
+                         big_lwe_dimension, num_blocks);
+    return;
+  }

   size_t r = num_radix_in_vec;
   size_t total_modulus = message_modulus * carry_modulus;
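The three early returns added above mirror the obvious clear-data fast paths: summing zero vectors is a no-op, one vector is a copy, and two vectors need a single addition with no carry-resolution tree. A small illustrative sketch (not the library's code):

```rust
// Clear-data sketch of the fast paths: summing `r` radix vectors blockwise.
// r == 0 is a no-op, r == 1 is a copy, r == 2 is one addition; only r >= 3
// needs the chunked carry-resolution tree. Names are local to this sketch.
fn partial_sum(terms: &[Vec<u64>]) -> Option<Vec<u64>> {
    match terms {
        [] => None,                   // nothing to sum
        [only] => Some(only.clone()), // copy through, no new carries produced
        [a, b] => Some(a.iter().zip(b).map(|(x, y)| x + y).collect()),
        _ => Some(sum_with_carry_tree(terms)), // the general path
    }
}

fn sum_with_carry_tree(terms: &[Vec<u64>]) -> Vec<u64> {
    // stand-in for the chunked accumulation the GPU kernel performs
    let len = terms[0].len();
    terms.iter().fold(vec![0; len], |mut acc, t| {
        for (a, x) in acc.iter_mut().zip(t) {
            *a += x;
        }
        acc
    })
}
```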
@@ -287,7 +303,6 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
         terms_degree, h_lwe_idx_in, h_lwe_idx_out, h_smart_copy_in,
         h_smart_copy_out, ch_amount, r, num_blocks, chunk_size, message_max,
         total_count, message_count, carry_count, sm_copy_count);
-    cuda_synchronize_stream(streams[0], gpu_indexes[0]);
     auto lwe_indexes_in = luts_message_carry->lwe_indexes_in;
     auto lwe_indexes_out = luts_message_carry->lwe_indexes_out;
     luts_message_carry->set_lwe_indexes(streams[0], gpu_indexes[0],
@@ -302,8 +317,11 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
     // inside d_smart_copy_in there are only -1 values
     // it's fine to call smart_copy with same pointer
     // as source and destination
+    cuda_memcpy_async_gpu_to_gpu(new_blocks_copy, new_blocks,
+                                 r * num_blocks * big_lwe_size * sizeof(Torus),
+                                 streams[0], gpu_indexes[0]);
     smart_copy<Torus><<<sm_copy_count, 1024, 0, streams[0]>>>(
-        new_blocks, new_blocks, d_smart_copy_out, d_smart_copy_in,
+        new_blocks, new_blocks_copy, d_smart_copy_out, d_smart_copy_in,
         big_lwe_size);
     check_cuda_error(cudaGetLastError());
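Why the extra copy: a gather kernel whose source and destination alias can read slots that another thread (or an earlier iteration) already overwrote. A minimal clear-data sketch of the hazard and the staging-buffer fix; all names are local to this sketch:

```rust
// In-place gather: may read a slot that was already overwritten.
fn smart_copy_aliased(data: &mut [u64], src_idx: &[i32], dst_idx: &[i32]) {
    for (&d, &s) in dst_idx.iter().zip(src_idx) {
        if s >= 0 {
            data[d as usize] = data[s as usize]; // hazard: source may be stale
        }
    }
}

// Staged gather: snapshot plays the role of `new_blocks_copy` above.
fn smart_copy_staged(data: &mut [u64], src_idx: &[i32], dst_idx: &[i32]) {
    let snapshot = data.to_vec(); // read-only copy of the pre-gather state
    for (&d, &s) in dst_idx.iter().zip(src_idx) {
        if s >= 0 {
            data[d as usize] = snapshot[s as usize]; // always reads old values
        }
    }
}
```

The sequential version only shows the ordering hazard; on a GPU the reads and writes race across threads, which is why the kernel now reads from the snapshot buffer.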
@@ -91,7 +91,6 @@ __host__ void host_integer_scalar_mul_radix(
       j++;
     }
   }
-  cuda_synchronize_stream(streams[0], gpu_indexes[0]);

   cuda_drop_async(preshifted_buffer, streams[0], gpu_indexes[0]);
   mem->logical_scalar_shift_buffer->release(streams, gpu_indexes, gpu_count);

@@ -109,7 +108,7 @@ __host__ void host_integer_scalar_mul_radix(
   host_integer_partial_sum_ciphertexts_vec_kb<T, params>(
       streams, gpu_indexes, gpu_count, lwe_array, all_shifted_buffer,
       terms_degree, bsks, ksks, mem->sum_ciphertexts_vec_mem,
-      num_radix_blocks, j);
+      num_radix_blocks, j, nullptr);

   auto scp_mem_ptr = mem->sum_ciphertexts_vec_mem->scp_mem;
   host_propagate_single_carry<T>(streams, gpu_indexes, gpu_count, lwe_array,
@@ -1,15 +1,5 @@
 #include "programmable_bootstrap_amortized.cuh"

-/*
- * Returns the buffer size for 64 bits executions
- */
-uint64_t get_buffer_size_programmable_bootstrap_amortized_64(
-    uint32_t glwe_dimension, uint32_t polynomial_size,
-    uint32_t input_lwe_ciphertext_count) {
-  return get_buffer_size_programmable_bootstrap_amortized<uint64_t>(
-      glwe_dimension, polynomial_size, input_lwe_ciphertext_count);
-}
-
 /*
  * This scratch function allocates the necessary amount of data on the GPU for
  * the amortized PBS on 32 bits inputs, into `buffer`. It also
@@ -256,7 +256,7 @@ __host__ void execute_cg_external_product_loop(
     pbs_buffer<Torus, MULTI_BIT> *buffer, uint32_t num_samples,
     uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
     uint32_t grouping_factor, uint32_t base_log, uint32_t level_count,
-    uint32_t lwe_chunk_size, int lwe_offset) {
+    uint32_t lwe_chunk_size, uint32_t lwe_offset) {

   uint64_t full_dm =
       get_buffer_size_full_sm_cg_multibit_programmable_bootstrap<Torus>(

@@ -275,6 +275,8 @@ __host__ void execute_cg_external_product_loop(

   uint32_t chunk_size =
       std::min(lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
+  if (chunk_size == 0)
+    return;

   auto d_mem = buffer->d_mem_acc_cg;
   auto keybundle_fft = buffer->keybundle_fft;
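The new `chunk_size == 0` guard protects the last step of a chunked traversal: when the processed range divides evenly by `lwe_chunk_size`, the final offset can leave nothing to do, and launching a zero-sized kernel would be invalid. A sketch of the pattern (illustrative, assumes `chunk > 0`):

```rust
// Demonstrates how min(chunk, total - offset) collapses to 0 exactly when the
// offset lands at the end of the range, which the GPU code now guards against.
fn chunked(total: u32, chunk: u32) {
    let mut offset = 0;
    while offset <= total {
        let size = chunk.min(total - offset);
        if size == 0 {
            return; // nothing left to process in this final step
        }
        // ... process `size` elements starting at `offset` ...
        offset += chunk;
    }
}
```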
@@ -182,25 +182,6 @@ void cuda_programmable_bootstrap_tbc_lwe_ciphertext_vector(
 }
 #endif

-/*
- * Returns the buffer size for 64 bits executions
- */
-uint64_t get_buffer_size_programmable_bootstrap_64(
-    uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
-    uint32_t input_lwe_ciphertext_count) {
-
-  if (has_support_to_cuda_programmable_bootstrap_cg<uint64_t>(
-          glwe_dimension, polynomial_size, level_count,
-          input_lwe_ciphertext_count))
-    return get_buffer_size_programmable_bootstrap_cg<uint64_t>(
-        glwe_dimension, polynomial_size, level_count,
-        input_lwe_ciphertext_count);
-  else
-    return get_buffer_size_programmable_bootstrap_cg<uint64_t>(
-        glwe_dimension, polynomial_size, level_count,
-        input_lwe_ciphertext_count);
-}
-
 template <typename Torus>
 void scratch_cuda_programmable_bootstrap_cg(
     void *stream, uint32_t gpu_index, pbs_buffer<Torus, CLASSICAL> **pbs_buffer,
@@ -465,7 +465,7 @@ __host__ void execute_compute_keybundle(
     pbs_buffer<Torus, MULTI_BIT> *buffer, uint32_t num_samples,
     uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
     uint32_t grouping_factor, uint32_t base_log, uint32_t level_count,
-    uint32_t lwe_chunk_size, int lwe_offset) {
+    uint32_t lwe_chunk_size, uint32_t lwe_offset) {

   uint32_t chunk_size =
       std::min(lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
@@ -506,14 +506,12 @@ __host__ void execute_compute_keybundle(
 }

 template <typename Torus, class params>
-__host__ void execute_step_one(cudaStream_t stream, uint32_t gpu_index,
-                               Torus *lut_vector, Torus *lut_vector_indexes,
-                               Torus *lwe_array_in, Torus *lwe_input_indexes,
-                               pbs_buffer<Torus, MULTI_BIT> *buffer,
-                               uint32_t num_samples, uint32_t lwe_dimension,
-                               uint32_t glwe_dimension,
-                               uint32_t polynomial_size, uint32_t base_log,
-                               uint32_t level_count, int j, int lwe_offset) {
+__host__ void execute_step_one(
+    cudaStream_t stream, uint32_t gpu_index, Torus *lut_vector,
+    Torus *lut_vector_indexes, Torus *lwe_array_in, Torus *lwe_input_indexes,
+    pbs_buffer<Torus, MULTI_BIT> *buffer, uint32_t num_samples,
+    uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
+    uint32_t base_log, uint32_t level_count, uint32_t j, uint32_t lwe_offset) {

   uint64_t full_sm_accumulate_step_one =
       get_buffer_size_full_sm_multibit_programmable_bootstrap_step_one<Torus>(
@@ -562,14 +560,12 @@ __host__ void execute_step_one(cudaStream_t stream, uint32_t gpu_index,
 }

 template <typename Torus, class params>
-__host__ void execute_step_two(cudaStream_t stream, uint32_t gpu_index,
-                               Torus *lwe_array_out, Torus *lwe_output_indexes,
-                               pbs_buffer<Torus, MULTI_BIT> *buffer,
-                               uint32_t num_samples, uint32_t lwe_dimension,
-                               uint32_t glwe_dimension,
-                               uint32_t polynomial_size,
-                               int32_t grouping_factor, uint32_t level_count,
-                               int j, int lwe_offset, uint32_t lwe_chunk_size) {
+__host__ void execute_step_two(
+    cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out,
+    Torus *lwe_output_indexes, pbs_buffer<Torus, MULTI_BIT> *buffer,
+    uint32_t num_samples, uint32_t lwe_dimension, uint32_t glwe_dimension,
+    uint32_t polynomial_size, int32_t grouping_factor, uint32_t level_count,
+    uint32_t j, uint32_t lwe_offset, uint32_t lwe_chunk_size) {

   uint64_t full_sm_accumulate_step_two =
       get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two<Torus>(
@@ -627,7 +623,7 @@ __host__ void host_multi_bit_programmable_bootstrap(
     // Accumulate
     uint32_t chunk_size = std::min(
         lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
-    for (int j = 0; j < chunk_size; j++) {
+    for (uint32_t j = 0; j < chunk_size; j++) {
       execute_step_one<Torus, params>(
           stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
           lwe_input_indexes, buffer, num_samples, lwe_dimension, glwe_dimension,
@@ -267,7 +267,7 @@ __host__ void execute_tbc_external_product_loop(
     pbs_buffer<Torus, MULTI_BIT> *buffer, uint32_t num_samples,
     uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
     uint32_t grouping_factor, uint32_t base_log, uint32_t level_count,
-    uint32_t lwe_chunk_size, int lwe_offset) {
+    uint32_t lwe_chunk_size, uint32_t lwe_offset) {

   auto supports_dsm =
       supports_distributed_shared_memory_on_multibit_programmable_bootstrap<

@@ -294,6 +294,8 @@ __host__ void execute_tbc_external_product_loop(

   uint32_t chunk_size =
       std::min(lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
+  if (chunk_size == 0)
+    return;

   auto d_mem = buffer->d_mem_acc_tbc;
   auto keybundle_fft = buffer->keybundle_fft;
@@ -222,6 +222,8 @@ __device__ void sample_extract_mask(Torus *lwe_array_out, Torus *glwe,
   Torus result[params::opt];
 #pragma unroll
   for (int i = 0; i < params::opt; i++) {
+    // params::degree - tid - 1 can't be negative, tid goes from 0 to
+    // params::degree - 1
     auto x = glwe_slice[params::degree - tid - 1];
     result[i] = SEL(-x, x, tid >= params::degree - nth);
     tid = tid + params::degree / params::opt;
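Context for the SEL line (editor's note): sample extraction builds the LWE mask from the GLWE mask polynomial by a negacyclic reversal, and the sign flip kicks in exactly where the index wraps, since X^n = -1 in the ring. A clear-data sketch of that indexing, following the standard extraction formula (names local to this sketch):

```rust
// Extracting coefficient `nth` of an LWE sample from a degree-n GLWE mask
// polynomial: take poly[nth - i] for i <= nth, then wrap around with a sign
// flip (X^n = -1 in the negacyclic ring). Illustrative only.
fn extract_mask_row(poly: &[i64], nth: usize) -> Vec<i64> {
    let n = poly.len();
    (0..n)
        .map(|i| {
            if i <= nth {
                poly[nth - i] // no wrap: plain reversal
            } else {
                -poly[n + nth - i] // wrapped: negated by X^n = -1
            }
        })
        .collect()
}
```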
@@ -747,7 +747,7 @@ extern "C" {
         gpu_indexes: *const u32,
         gpu_count: u32,
         radix_lwe_out: *mut c_void,
-        radix_lwe_vec: *const c_void,
+        radix_lwe_vec: *mut c_void,
         num_radix_in_vec: u32,
         mem_ptr: *mut i8,
         bsks: *const *mut c_void,
@@ -959,7 +959,7 @@ extern "C" {
         gpu_indexes: *const u32,
         gpu_count: u32,
         output_radix_lwe: *mut c_void,
-        input_radix_lwe: *const c_void,
+        generates_or_propagates: *mut c_void,
         mem_ptr: *mut i8,
         ksks: *const *mut c_void,
         bsks: *const *mut c_void,
@@ -58,7 +58,7 @@ flavor_name = "n3-A100x8-NVLink"
 [backend.hyperstack.multi-gpu-test]
 environment_name = "canada"
 image_name = "Ubuntu Server 22.04 LTS R535 CUDA 12.2"
-flavor_name = "n3-A100x4"
+flavor_name = "n3-RTX-A6000x4"

 [command.signed_integer_full_bench]
 workflow = "signed_integer_full_benchmark.yml"
@@ -133,8 +133,8 @@ if [[ "${backend}" == "gpu" ]]; then
         test_threads=8
         doctest_threads=8
     else
-        test_threads=3
-        doctest_threads=3
+        test_threads=1
+        doctest_threads=1
     fi
 fi
@@ -1813,6 +1813,12 @@ mod cuda {
         rng_func: default_signed_scalar
     );

+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: unchecked_signed_scalar_div_rem,
+        display_name: div_rem,
+        rng_func: div_scalar
+    );
+
     //===========================================
     // Default
     //===========================================

@@ -2035,6 +2041,12 @@ mod cuda {
         rng_func: default_signed_scalar
     );

+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: signed_scalar_div_rem,
+        display_name: div_rem,
+        rng_func: div_scalar
+    );
+
     criterion_group!(
         unchecked_cuda_ops,
         cuda_unchecked_add,

@@ -2081,6 +2093,7 @@ mod cuda {
         cuda_unchecked_scalar_le,
         cuda_unchecked_scalar_min,
         cuda_unchecked_scalar_max,
+        cuda_unchecked_signed_scalar_div_rem,
     );

     criterion_group!(

@@ -2146,6 +2159,7 @@ mod cuda {
         cuda_scalar_max,
         cuda_signed_overflowing_scalar_add,
         cuda_signed_overflowing_scalar_sub,
+        cuda_signed_scalar_div_rem,
     );

     fn cuda_bench_server_key_signed_cast_function<F>(
@@ -365,7 +365,6 @@ where
 // DivRem is a bit special as it returns a tuple of quotient and remainder
 macro_rules! generic_integer_impl_scalar_div_rem {
     (
         key_method: $key_method:ident,
         // A 'list' of tuple, where the first element is the concrete Fhe type
         // e.g (FheUint8 and the rest is scalar types (u8, u16, etc)
         fhe_and_scalar_type: $(

@@ -393,15 +392,24 @@ macro_rules! generic_integer_impl_scalar_div_rem {
                     InternalServerKey::Cpu(cpu_key) => {
                         let (q, r) = cpu_key
                             .pbs_key()
-                            .$key_method(&*self.ciphertext.on_cpu(), rhs);
+                            .signed_scalar_div_rem_parallelized(&*self.ciphertext.on_cpu(), rhs);
                         (
                             <$concrete_type>::new(q, cpu_key.tag.clone()),
                             <$concrete_type>::new(r, cpu_key.tag.clone())
                         )
                     }
                     #[cfg(feature = "gpu")]
-                    InternalServerKey::Cuda(_) => {
-                        panic!("Cuda devices does not support div rem yet")
+                    InternalServerKey::Cuda(cuda_key) => {
+                        let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| {
+                            cuda_key.key.signed_scalar_div_rem(
+                                &*self.ciphertext.on_gpu(), rhs, streams
+                            )
+                        });
+                        let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r));
+                        (
+                            <$concrete_type>::new(q, cuda_key.tag.clone()),
+                            <$concrete_type>::new(r, cuda_key.tag.clone())
+                        )
                     }
                 })
             }

@@ -410,8 +418,8 @@ macro_rules! generic_integer_impl_scalar_div_rem {
     )* // Closing first repeating pattern
     };
 }

 generic_integer_impl_scalar_div_rem!(
     key_method: signed_scalar_div_rem_parallelized,
     fhe_and_scalar_type:
         (super::FheInt2, i8),
         (super::FheInt4, i8),
@@ -446,7 +446,6 @@ where
 // DivRem is a bit special as it returns a tuple of quotient and remainder
 macro_rules! generic_integer_impl_scalar_div_rem {
     (
         key_method: $key_method:ident,
         // A 'list' of tuple, where the first element is the concrete Fhe type
         // e.g (FheUint8 and the rest is scalar types (u8, u16, etc)
         fhe_and_scalar_type: $(

@@ -473,15 +472,24 @@ macro_rules! generic_integer_impl_scalar_div_rem {
                 global_state::with_internal_keys(|key| {
                     match key {
                         InternalServerKey::Cpu(cpu_key) => {
-                            let (q, r) = cpu_key.pbs_key().$key_method(&*self.ciphertext.on_cpu(), rhs);
+                            let (q, r) = cpu_key.pbs_key().scalar_div_rem_parallelized(&*self.ciphertext.on_cpu(), rhs);
                             (
                                 <$concrete_type>::new(q, cpu_key.tag.clone()),
                                 <$concrete_type>::new(r, cpu_key.tag.clone())
                             )
                         }
                         #[cfg(feature = "gpu")]
-                        InternalServerKey::Cuda(_) => {
-                            panic!("Cuda devices do not support div_rem yet");
+                        InternalServerKey::Cuda(cuda_key) => {
+                            let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| {
+                                cuda_key.key.scalar_div_rem(
+                                    &*self.ciphertext.on_gpu(), rhs, streams
+                                )
+                            });
+                            let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r));
+                            (
+                                <$concrete_type>::new(q, cuda_key.tag.clone()),
+                                <$concrete_type>::new(r, cuda_key.tag.clone())
+                            )
+                        }
                     }
                 })

@@ -492,7 +500,6 @@ macro_rules! generic_integer_impl_scalar_div_rem {
     };
 }
 generic_integer_impl_scalar_div_rem!(
     key_method: scalar_div_rem_parallelized,
     fhe_and_scalar_type:
         (super::FheUint2, u8),
         (super::FheUint4, u8),
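A hedged usage sketch of the path wired up above, assuming the crate's usual high-level flow (ConfigBuilder, generate_keys, set_server_key); with the `gpu` feature and a CUDA server key set, the Cuda arm of the macro above is taken. The operator calls below stand in for the paired div_rem entry point and may differ between versions:

```rust
use tfhe::prelude::*;
use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint8};

fn main() {
    let config = ConfigBuilder::default().build();
    let (client_key, server_key) = generate_keys(config);
    set_server_key(server_key); // a CudaServerKey here routes to the GPU path

    let a = FheUint8::encrypt(43u8, &client_key);
    // the scalar divisor stays in the clear; quotient and remainder are encrypted
    let q = &a / 8u8;
    let r = &a % 8u8;
    let (q, r): (u8, u8) = (q.decrypt(&client_key), r.decrypt(&client_key));
    assert_eq!((q, r), (43 / 8, 43 % 8));
}
```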
@@ -2578,7 +2578,7 @@ pub unsafe fn unchecked_signed_overflowing_add_or_sub_radix_kb_assign_async<
 pub unsafe fn compute_prefix_sum_hillis_steele_async<T: UnsignedInteger, B: Numeric>(
     streams: &CudaStreams,
     radix_lwe_output: &mut CudaSliceMut<T>,
-    radix_lwe_input: &CudaSlice<T>,
+    generates_or_propagates: &mut CudaSliceMut<T>,
     input_lut: &[T],
     bootstrapping_key: &CudaVec<B>,
     keyswitch_key: &CudaVec<T>,

@@ -2598,7 +2598,7 @@ pub unsafe fn compute_prefix_sum_hillis_steele_async<T: UnsignedInteger, B: Nume
 ) {
     assert_eq!(
         streams.gpu_indexes[0],
-        radix_lwe_input.gpu_index(0),
+        generates_or_propagates.gpu_index(0),
         "GPU error: all data should reside on the same GPU."
     );
     assert_eq!(

@@ -2643,7 +2643,7 @@ pub unsafe fn compute_prefix_sum_hillis_steele_async<T: UnsignedInteger, B: Nume
         streams.gpu_indexes.as_ptr(),
         streams.len() as u32,
         radix_lwe_output.as_mut_c_ptr(0),
-        radix_lwe_input.as_c_ptr(0),
+        generates_or_propagates.as_mut_c_ptr(0),
         mem_ptr,
         keyswitch_key.ptr.as_ptr(),
         bootstrapping_key.ptr.as_ptr(),
@@ -40,8 +40,7 @@ impl CudaServerKey {
         let lwe_size = ct.as_ref().d_blocks.0.lwe_dimension.to_lwe_size().0;

         // Allocate the necessary amount of memory
-        let mut output_radix =
-            CudaVec::new(num_ct_blocks * lwe_size, streams, streams.gpu_indexes[0]);
+        let mut tmp_radix = CudaVec::new(num_ct_blocks * lwe_size, streams, streams.gpu_indexes[0]);

         let lut = match direction {
             Direction::Trailing => self.generate_lookup_table(|x| {

@@ -70,12 +69,12 @@ impl CudaServerKey {
             }),
         };

-        output_radix.copy_from_gpu_async(
+        tmp_radix.copy_from_gpu_async(
             &ct.as_ref().d_blocks.0.d_vec,
             streams,
             streams.gpu_indexes[0],
         );
-        let mut output_slice = output_radix
+        let mut output_slice = tmp_radix
             .as_mut_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0])
             .unwrap();

@@ -167,27 +166,27 @@ impl CudaServerKey {
             },
         );

-        let mut cts = CudaLweCiphertextList::new(
+        let mut output_cts = CudaLweCiphertextList::new(
            ct.as_ref().d_blocks.lwe_dimension(),
            LweCiphertextCount(num_ct_blocks * ct.as_ref().d_blocks.lwe_ciphertext_count().0),
            ct.as_ref().d_blocks.ciphertext_modulus(),
            streams,
        );

-        let input_radix_slice = output_radix
-            .as_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0])
+        let mut generates_or_propagates = tmp_radix
+            .as_mut_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0])
             .unwrap();

         match &self.bootstrapping_key {
             CudaBootstrappingKey::Classic(d_bsk) => {
                 compute_prefix_sum_hillis_steele_async(
                     streams,
-                    &mut cts
+                    &mut output_cts
                         .0
                         .d_vec
                         .as_mut_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0])
                         .unwrap(),
-                    &input_radix_slice,
+                    &mut generates_or_propagates,
                     sum_lut.acc.acc.as_ref(),
                     &d_bsk.d_vec,
                     &self.key_switching_key.d_vec,

@@ -211,12 +210,12 @@ impl CudaServerKey {
             CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
                 compute_prefix_sum_hillis_steele_async(
                     streams,
-                    &mut cts
+                    &mut output_cts
                         .0
                         .d_vec
                         .as_mut_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0])
                         .unwrap(),
-                    &input_radix_slice,
+                    &mut generates_or_propagates,
                     sum_lut.acc.acc.as_ref(),
                     &d_multibit_bsk.d_vec,
                     &self.key_switching_key.d_vec,

@@ -238,7 +237,7 @@ impl CudaServerKey {
                 );
             }
         }
-        cts
+        output_cts
     }

     /// Counts how many consecutive bits there are
@@ -1,9 +1,13 @@
 use crate::core_crypto::gpu::CudaStreams;
 use crate::core_crypto::prelude::Numeric;
 use crate::integer::block_decomposition::DecomposableInto;
-use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
+use crate::integer::gpu::ciphertext::{
+    CudaIntegerRadixCiphertext, CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext,
+};
 use crate::integer::gpu::CudaServerKey;
-use crate::integer::server_key::radix_parallel::scalar_div_mod::choose_multiplier;
+use crate::integer::server_key::radix_parallel::scalar_div_mod::{
+    choose_multiplier, SignedReciprocable,
+};
 use crate::integer::server_key::{MiniUnsignedInteger, Reciprocable, ScalarMultiplier};
 use crate::prelude::{CastFrom, CastInto};
@@ -32,6 +36,21 @@ impl CudaServerKey {
         result
     }

+    fn signed_scalar_mul_high<Scalar>(
+        &self,
+        lhs: &CudaSignedRadixCiphertext,
+        rhs: Scalar,
+        streams: &CudaStreams,
+    ) -> CudaSignedRadixCiphertext
+    where
+        Scalar: ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
+    {
+        let num_blocks = lhs.as_ref().d_blocks.lwe_ciphertext_count().0;
+        let mut result = self.extend_radix_with_sign_msb(lhs, num_blocks, streams);
+        self.scalar_mul_assign(&mut result, rhs, streams);
+        self.trim_radix_blocks_lsb(&result, num_blocks, streams)
+    }
+
     /// Computes homomorphically a division between a ciphertext and a scalar.
     ///
     /// This function computes the operation without checking if it exceeds the capacity of the
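Clear-integer intuition for `signed_scalar_mul_high` (editor's sketch): it is the high half of a widening signed multiply, which the radix version obtains by sign-extending with `num_blocks` extra blocks, multiplying, then dropping the low blocks:

```rust
// High half of a widening signed multiply, on plain i64 for intuition only.
fn signed_mul_high(lhs: i64, rhs: i64) -> i64 {
    (((lhs as i128) * (rhs as i128)) >> 64) as i64
}

#[test]
fn mul_high_matches_widening_multiply() {
    assert_eq!(signed_mul_high(i64::MAX, 2), 0); // 2 * MAX < 2^64: high half is 0
    assert_eq!(signed_mul_high(i64::MIN, 2), -1); // negative products sign-extend
}
```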
@@ -403,4 +422,262 @@ impl CudaServerKey {

         self.unchecked_scalar_rem(numerator, divisor, streams)
     }

    pub fn unchecked_signed_scalar_div<Scalar>(
        &self,
        numerator: &CudaSignedRadixCiphertext,
        divisor: Scalar,
        streams: &CudaStreams,
    ) -> CudaSignedRadixCiphertext
    where
        Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
        <<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
    {
        assert_ne!(divisor, Scalar::ZERO, "attempt to divide by 0");

        let numerator_bits = self.message_modulus.0.ilog2()
            * numerator.ciphertext.d_blocks.lwe_ciphertext_count().0 as u32;
        assert!(
            Scalar::BITS >= numerator_bits as usize,
            "The scalar divisor type must have a number of bits that is \
            >= to the number of bits encrypted in the ciphertext"
        );

        // wrapping_abs returns Scalar::MIN when its input is Scalar::MIN (since in signed numbers
        // Scalar::MIN's absolute value cannot be represented).
        // However, casting Scalar::MIN to an unsigned value gives the correct abs value,
        // if Scalar and Scalar::Unsigned have the same number of bits.
        let absolute_divisor = Scalar::Unsigned::cast_from(divisor.wrapping_abs());

        if absolute_divisor == Scalar::Unsigned::ONE {
            // Strangely, the paper says: Issue q = d;
            return if divisor < Scalar::ZERO {
                // quotient = -quotient;
                self.neg(numerator, streams)
            } else {
                numerator.duplicate(streams)
            };
        }

        let chosen_multiplier =
            choose_multiplier(absolute_divisor, numerator_bits - 1, numerator_bits);

        if chosen_multiplier.l >= numerator_bits {
            return self.create_trivial_zero_radix(
                numerator.ciphertext.d_blocks.lwe_ciphertext_count().0,
                streams,
            );
        }

        let quotient;
        if absolute_divisor == (Scalar::Unsigned::ONE << chosen_multiplier.l as usize) {
            // Issue q = SRA(n + SRL(SRA(n, l − 1), N − l), l);
            let l = chosen_multiplier.l;

            // SRA(n, l − 1)
            let mut tmp = self.unchecked_scalar_right_shift(numerator, l - 1, streams);

            // SRL(SRA(n, l − 1), N − l)
            unsafe {
                self.unchecked_scalar_right_shift_logical_assign_async(
                    &mut tmp,
                    (numerator_bits - l) as usize,
                    streams,
                );
            }
            streams.synchronize();
            // n + SRL(SRA(n, l − 1), N − l)
            self.add_assign(&mut tmp, numerator, streams);
            // SRA(n + SRL(SRA(n, l − 1), N − l), l);
            quotient = self.unchecked_scalar_right_shift(&tmp, l, streams);
        } else if chosen_multiplier.multiplier
            < (<Scalar::Unsigned as Reciprocable>::DoublePrecision::ONE << (numerator_bits - 1))
        {
            // in the condition above works (it makes more values take this branch,
            // but results still seemed correct)

            // multiplier is less than the max possible value of Scalar
            // Issue q = SRA(MULSH(m, n), shpost) − XSIGN(n);

            let (mut tmp, xsign) = rayon::join(
                move || {
                    // MULSH(m, n)
                    let mut tmp = self.signed_scalar_mul_high(
                        numerator,
                        chosen_multiplier.multiplier,
                        streams,
                    );

                    // SRA(MULSH(m, n), shpost)
                    unsafe {
                        self.unchecked_scalar_right_shift_assign_async(
                            &mut tmp,
                            chosen_multiplier.shift_post,
                            streams,
                        );
                    }
                    streams.synchronize();
                    tmp
                },
                || {
                    // XSIGN is: if x < 0 { -1 } else { 0 }
                    // It is equivalent to SRA(x, N − 1)
                    self.unchecked_scalar_right_shift(numerator, numerator_bits - 1, streams)
                },
            );

            self.sub_assign(&mut tmp, &xsign, streams);
            quotient = tmp;
        } else {
            // Issue q = SRA(n + MULSH(m − 2^N , n), shpost) − XSIGN(n);
            // Note from the paper: m - 2^N is negative

            let (mut tmp, xsign) = rayon::join(
                move || {
                    // The subtraction may overflow.
                    // We then cast the result to a signed type.
                    // Overall, this will work fine due to two's complement representation
                    let cst = chosen_multiplier.multiplier
                        - (<Scalar::Unsigned as Reciprocable>::DoublePrecision::ONE
                            << numerator_bits);
                    let cst = Scalar::DoublePrecision::cast_from(cst);

                    // MULSH(m - 2^N, n)
                    let mut tmp = self.signed_scalar_mul_high(numerator, cst, streams);

                    // n + MULSH(m − 2^N , n)
                    self.add_assign(&mut tmp, numerator, streams);

                    // SRA(n + MULSH(m - 2^N, n), shpost)
                    tmp = self.unchecked_scalar_right_shift(
                        &tmp,
                        chosen_multiplier.shift_post,
                        streams,
                    );

                    tmp
                },
                || {
                    // XSIGN is: if x < 0 { -1 } else { 0 }
                    // It is equivalent to SRA(x, N − 1)
                    self.unchecked_scalar_right_shift(numerator, numerator_bits - 1, streams)
                },
            );

            self.sub_assign(&mut tmp, &xsign, streams);
            quotient = tmp;
        }

        if divisor < Scalar::ZERO {
            self.neg(&quotient, streams)
        } else {
            quotient
        }
    }
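The branch structure above follows the Granlund-Montgomery style of division by an invariant integer. A clear i32 sketch of the middle branch, q = SRA(MULSH(m, n), shpost) − XSIGN(n), with a well-known magic pair for d = 3 (all names local to this sketch):

```rust
// q = SRA(MULSH(m, n), shpost) - XSIGN(n), on plain integers for intuition.
fn div_by_const_mid_branch(n: i32, m: u64, shpost: u32) -> i32 {
    let mulsh = ((n as i64 as i128 * m as i128) >> 32) as i64; // MULSH(m, n)
    let sra = (mulsh >> shpost) as i32; // SRA(..., shpost)
    let xsign = n >> 31; // -1 if n < 0 else 0
    sra - xsign // rounds toward zero, like hardware signed division
}

#[test]
fn divides_by_three() {
    let (m, shpost) = (0x5555_5556u64, 0); // magic pair for d = 3, N = 32
    for n in [-100i32, -7, -1, 0, 1, 7, 100] {
        assert_eq!(div_by_const_mid_branch(n, m, shpost), n / 3);
    }
}
```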
    pub fn signed_scalar_div<Scalar>(
        &self,
        numerator: &CudaSignedRadixCiphertext,
        divisor: Scalar,
        streams: &CudaStreams,
    ) -> CudaSignedRadixCiphertext
    where
        Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
        <<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
    {
        let mut tmp_numerator;
        let numerator = if numerator.block_carries_are_empty() {
            numerator
        } else {
            unsafe {
                tmp_numerator = numerator.duplicate_async(streams);
                self.full_propagate_assign_async(&mut tmp_numerator, streams);
            }
            &tmp_numerator
        };

        self.unchecked_signed_scalar_div(numerator, divisor, streams)
    }

    pub fn unchecked_signed_scalar_div_rem<Scalar>(
        &self,
        numerator: &CudaSignedRadixCiphertext,
        divisor: Scalar,
        streams: &CudaStreams,
    ) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext)
    where
        Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
        <<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
    {
        let quotient = self.unchecked_signed_scalar_div(numerator, divisor, streams);

        // remainder = numerator - (quotient * divisor)
        let tmp = self.unchecked_scalar_mul(&quotient, divisor, streams);
        let remainder = self.sub(numerator, &tmp, streams);

        (quotient, remainder)
    }
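The remainder is recovered from the quotient through the Euclidean identity n = q·d + r under truncated (round-toward-zero) division, so r carries the numerator's sign. A quick clear-data check of that convention:

```rust
#[test]
fn remainder_follows_truncated_division() {
    for (n, d) in [(43i64, 8i64), (43, -8), (-43, 8), (-43, -8)] {
        let q = n / d; // Rust's / truncates toward zero
        let r = n - q * d; // same value the GPU path computes
        assert_eq!(r, n % d);
        assert_eq!(n, q * d + r);
    }
}
```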
    pub fn signed_scalar_div_rem<Scalar>(
        &self,
        numerator: &CudaSignedRadixCiphertext,
        divisor: Scalar,
        streams: &CudaStreams,
    ) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext)
    where
        Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
        <<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
    {
        let mut tmp_numerator;
        let numerator = if numerator.block_carries_are_empty() {
            numerator
        } else {
            unsafe {
                tmp_numerator = numerator.duplicate_async(streams);
                self.full_propagate_assign_async(&mut tmp_numerator, streams);
            }
            &tmp_numerator
        };

        self.unchecked_signed_scalar_div_rem(numerator, divisor, streams)
    }

    pub fn unchecked_signed_scalar_rem<Scalar>(
        &self,
        numerator: &CudaSignedRadixCiphertext,
        divisor: Scalar,
        streams: &CudaStreams,
    ) -> CudaSignedRadixCiphertext
    where
        Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
        <<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
    {
        let (_, remainder) = self.unchecked_signed_scalar_div_rem(numerator, divisor, streams);

        remainder
    }

    pub fn signed_scalar_rem<Scalar>(
        &self,
        numerator: &CudaSignedRadixCiphertext,
        divisor: Scalar,
        streams: &CudaStreams,
    ) -> CudaSignedRadixCiphertext
    where
        Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
        <<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
    {
        let mut tmp_numerator;
        let numerator = if numerator.block_carries_are_empty() {
            numerator
        } else {
            unsafe {
                tmp_numerator = numerator.duplicate_async(streams);
                self.full_propagate_assign_async(&mut tmp_numerator, streams);
            }
            &tmp_numerator
        };

        self.unchecked_signed_scalar_rem(numerator, divisor, streams)
    }
}
@@ -565,4 +565,76 @@ impl CudaServerKey {
        };
        stream.synchronize();
    }

    /// # Safety
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    ///   not be dropped until the stream is synchronised
    pub unsafe fn unchecked_scalar_right_shift_logical_assign_async<Scalar, T>(
        &self,
        ct: &mut T,
        shift: Scalar,
        stream: &CudaStreams,
    ) where
        Scalar: CastFrom<u32>,
        u32: CastFrom<Scalar>,
        T: CudaIntegerRadixCiphertext,
    {
        let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();

        match &self.bootstrapping_key {
            CudaBootstrappingKey::Classic(d_bsk) => {
                unchecked_scalar_logical_right_shift_integer_radix_kb_assign_async(
                    stream,
                    &mut ct.as_mut().d_blocks.0.d_vec,
                    u32::cast_from(shift),
                    &d_bsk.d_vec,
                    &self.key_switching_key.d_vec,
                    self.message_modulus,
                    self.carry_modulus,
                    d_bsk.glwe_dimension,
                    d_bsk.polynomial_size,
                    self.key_switching_key
                        .input_key_lwe_size()
                        .to_lwe_dimension(),
                    self.key_switching_key
                        .output_key_lwe_size()
                        .to_lwe_dimension(),
                    self.key_switching_key.decomposition_level_count(),
                    self.key_switching_key.decomposition_base_log(),
                    d_bsk.decomp_level_count,
                    d_bsk.decomp_base_log,
                    lwe_ciphertext_count.0 as u32,
                    PBSType::Classical,
                    LweBskGroupingFactor(0),
                );
            }
            CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
                unchecked_scalar_logical_right_shift_integer_radix_kb_assign_async(
                    stream,
                    &mut ct.as_mut().d_blocks.0.d_vec,
                    u32::cast_from(shift),
                    &d_multibit_bsk.d_vec,
                    &self.key_switching_key.d_vec,
                    self.message_modulus,
                    self.carry_modulus,
                    d_multibit_bsk.glwe_dimension,
                    d_multibit_bsk.polynomial_size,
                    self.key_switching_key
                        .input_key_lwe_size()
                        .to_lwe_dimension(),
                    self.key_switching_key
                        .output_key_lwe_size()
                        .to_lwe_dimension(),
                    self.key_switching_key.decomposition_level_count(),
                    self.key_switching_key.decomposition_base_log(),
                    d_multibit_bsk.decomp_level_count,
                    d_multibit_bsk.decomp_base_log,
                    lwe_ciphertext_count.0 as u32,
                    PBSType::MultiBit,
                    d_multibit_bsk.grouping_factor,
                );
            }
        }
    }
}
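The division algorithm needs both shift flavors on signed values, SRA (arithmetic, sign-extending) and SRL (logical, zero-filling), which is why a dedicated logical-shift entry point is added above. On clear integers the split looks like:

```rust
fn sra(x: i64, s: u32) -> i64 {
    x >> s // >> on a signed type sign-extends
}

fn srl(x: i64, s: u32) -> i64 {
    ((x as u64) >> s) as i64 // route through unsigned to zero-fill
}

#[test]
fn shift_flavors_differ_on_negatives() {
    assert_eq!(sra(-8, 1), -4);
    assert_eq!(srl(-8, 1), i64::MAX - 3); // 0x7FFF_FFFF_FFFF_FFFC
}
```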
@@ -9,6 +9,7 @@ pub(crate) mod test_rotate;
 pub(crate) mod test_scalar_add;
 pub(crate) mod test_scalar_bitwise_op;
 pub(crate) mod test_scalar_comparison;
+pub(crate) mod test_scalar_div_mod;
 pub(crate) mod test_scalar_mul;
 pub(crate) mod test_scalar_rotate;
 pub(crate) mod test_scalar_shift;
@@ -265,6 +266,46 @@ where
    }
}

/// For unchecked/default binary functions with one scalar input and two encrypted outputs
impl<'a, F>
    FunctionExecutor<
        (&'a SignedRadixCiphertext, i64),
        (SignedRadixCiphertext, SignedRadixCiphertext),
    > for GpuFunctionExecutor<F>
where
    F: Fn(
        &CudaServerKey,
        &CudaSignedRadixCiphertext,
        i64,
        &CudaStreams,
    ) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext),
{
    fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
        self.setup_from_keys(cks, &sks);
    }

    fn execute(
        &mut self,
        input: (&'a SignedRadixCiphertext, i64),
    ) -> (SignedRadixCiphertext, SignedRadixCiphertext) {
        let context = self
            .context
            .as_ref()
            .expect("setup was not properly called");

        let d_ctxt_1 =
            CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);

        let (gpu_result_1, gpu_result_2) =
            (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);

        (
            gpu_result_1.to_signed_radix_ciphertext(&context.streams),
            gpu_result_2.to_signed_radix_ciphertext(&context.streams),
        )
    }
}

impl<'a, F> FunctionExecutor<(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), BooleanBlock>
    for GpuFunctionExecutor<F>
where
@@ -0,0 +1,16 @@
use crate::integer::gpu::server_key::radix::tests_unsigned::{
    create_gpu_parametrized_test, GpuFunctionExecutor,
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_div_mod::signed_unchecked_scalar_div_rem_test;
use crate::shortint::parameters::*;

create_gpu_parametrized_test!(integer_signed_unchecked_scalar_div_rem);

fn integer_signed_unchecked_scalar_div_rem<P>(param: P)
where
    P: Into<PBSParameters>,
{
    let executor = GpuFunctionExecutor::new(&CudaServerKey::signed_scalar_div_rem);
    signed_unchecked_scalar_div_rem_test(param, executor);
}
@@ -135,6 +135,7 @@ pub trait SignedReciprocable:
     type DoublePrecision: DecomposableInto<u8>
         + ScalarMultiplier
         + CastFrom<<Self::Unsigned as Reciprocable>::DoublePrecision>
+        + CastInto<u64>
         + std::fmt::Debug;

     fn wrapping_abs(self) -> Self;
@@ -10,6 +10,7 @@ pub(crate) mod test_rotate;
 pub(crate) mod test_scalar_add;
 pub(crate) mod test_scalar_bitwise_op;
 pub(crate) mod test_scalar_comparison;
+pub(crate) mod test_scalar_div_mod;
 pub(crate) mod test_scalar_mul;
 pub(crate) mod test_scalar_rotate;
 pub(crate) mod test_scalar_shift;
@@ -488,154 +489,8 @@ fn integer_signed_default_absolute_value(param: impl Into<PBSParameters>) {
//================================================================================
// Unchecked Scalar Tests
//================================================================================
create_parametrized_test!(integer_signed_unchecked_scalar_div_rem);
create_parametrized_test!(integer_signed_unchecked_scalar_div_rem_floor);

fn integer_signed_unchecked_scalar_div_rem(param: impl Into<PBSParameters>) {
    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

    let mut rng = rand::thread_rng();

    let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;

    {
        let clear_0 = rng.gen::<i64>() % modulus;
        let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);

        let result = std::panic::catch_unwind(|| {
            let _ = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, 0);
        });
        assert!(result.is_err(), "Division by zero did not panic");
    }

    // check when scalar is out of ciphertext MIN..=MAX
    for d in [
        rng.gen_range(i64::MIN..-modulus),
        rng.gen_range(modulus..=i64::MAX),
    ] {
        for numerator in [rng.gen_range(-modulus..=0), rng.gen_range(0..modulus)] {
            let ctxt_0 = cks.encrypt_signed_radix(numerator, NB_CTXT);

            let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, d);
            let q: i64 = cks.decrypt_signed_radix(&q_res);
            let r: i64 = cks.decrypt_signed_radix(&r_res);
            assert_eq!(q, signed_div_under_modulus(numerator, d, modulus));
            assert_eq!(r, signed_rem_under_modulus(numerator, d, modulus));
        }
    }

    // The algorithm has a special case for when divisor is 1 or -1
    for d in [1i64, -1i64] {
        let clear_0 = rng.gen::<i64>() % modulus;

        let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);

        let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, d);
        let q: i64 = cks.decrypt_signed_radix(&q_res);
        let r: i64 = cks.decrypt_signed_radix(&r_res);
        assert_eq!(q, signed_div_under_modulus(clear_0, d, modulus));
        assert_eq!(r, signed_rem_under_modulus(clear_0, d, modulus));
    }

    // 3 / -3 takes the second branch in the if else if series
    for d in [3, -3] {
        {
            let neg_clear_0 = rng.gen_range(-modulus..=0);
            let ctxt_0 = cks.encrypt_signed_radix(neg_clear_0, NB_CTXT);
            println!("{neg_clear_0} / {d}");
            let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, d);
            let q: i64 = cks.decrypt_signed_radix(&q_res);
            let r: i64 = cks.decrypt_signed_radix(&r_res);
            assert_eq!(q, signed_div_under_modulus(neg_clear_0, d, modulus));
            assert_eq!(r, signed_rem_under_modulus(neg_clear_0, d, modulus));
        }

        {
            let pos_clear_0 = rng.gen_range(0..modulus);
            let ctxt_0 = cks.encrypt_signed_radix(pos_clear_0, NB_CTXT);
            println!("{pos_clear_0} / {d}");
            let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, d);
            let q: i64 = cks.decrypt_signed_radix(&q_res);
            let r: i64 = cks.decrypt_signed_radix(&r_res);
            assert_eq!(q, signed_div_under_modulus(pos_clear_0, d, modulus));
            assert_eq!(r, signed_rem_under_modulus(pos_clear_0, d, modulus));
        }
    }

    // Param 1_1 cannot do this, with our NB_CTXT
    if modulus >= 43 {
        // For param_2_2 this will take the third branch in the if else if series
        for d in [-89, 89] {
            {
                let neg_clear_0 = rng.gen_range(-modulus..=0);
                let ctxt_0 = cks.encrypt_signed_radix(neg_clear_0, NB_CTXT);
                let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, d);
                let q: i64 = cks.decrypt_signed_radix(&q_res);
                let r: i64 = cks.decrypt_signed_radix(&r_res);
                assert_eq!(q, signed_div_under_modulus(neg_clear_0, d, modulus));
                assert_eq!(r, signed_rem_under_modulus(neg_clear_0, d, modulus));
            }

            {
                let pos_clear_0 = rng.gen_range(0..modulus);
                let ctxt_0 = cks.encrypt_signed_radix(pos_clear_0, NB_CTXT);
                println!("{pos_clear_0} / {d}");
                let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, d);
                let q: i64 = cks.decrypt_signed_radix(&q_res);
                let r: i64 = cks.decrypt_signed_radix(&r_res);
                assert_eq!(q, signed_div_under_modulus(pos_clear_0, d, modulus));
                assert_eq!(r, signed_rem_under_modulus(pos_clear_0, d, modulus));
            }
        }

        // For param_2_2 this will take the first branch
        for (clear_0, clear_1) in [(43, 8), (43, -8), (-43, 8), (-43, -8)] {
            let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);

            let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, clear_1);
            let q: i64 = cks.decrypt_signed_radix(&q_res);
            let r: i64 = cks.decrypt_signed_radix(&r_res);
            assert_eq!(q, signed_div_under_modulus(clear_0, clear_1, modulus));
            assert_eq!(r, signed_rem_under_modulus(clear_0, clear_1, modulus));
        }
    }

    for d in [-modulus, modulus - 1] {
        {
            let neg_clear_0 = rng.gen_range(-modulus..=0);
            let ctxt_0 = cks.encrypt_signed_radix(neg_clear_0, NB_CTXT);
            let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, d);
            let q: i64 = cks.decrypt_signed_radix(&q_res);
            let r: i64 = cks.decrypt_signed_radix(&r_res);
            assert_eq!(q, signed_div_under_modulus(neg_clear_0, d, modulus));
            assert_eq!(r, signed_rem_under_modulus(neg_clear_0, d, modulus));
        }

        {
            let pos_clear_0 = rng.gen_range(0..modulus);
            let ctxt_0 = cks.encrypt_signed_radix(pos_clear_0, NB_CTXT);
            let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, d);
            let q: i64 = cks.decrypt_signed_radix(&q_res);
            let r: i64 = cks.decrypt_signed_radix(&r_res);
            assert_eq!(q, signed_div_under_modulus(pos_clear_0, d, modulus));
            assert_eq!(r, signed_rem_under_modulus(pos_clear_0, d, modulus));
        }
    }

    let lhs_values = random_signed_value_under_modulus::<6>(&mut rng, modulus);
    let rhs_values = random_non_zero_signed_value_under_modulus::<6>(&mut rng, modulus);

    for (clear_lhs, clear_rhs) in iproduct!(lhs_values, rhs_values) {
        let ctxt_0 = cks.encrypt_signed_radix(clear_lhs, NB_CTXT);

        let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, clear_rhs);
        let q: i64 = cks.decrypt_signed_radix(&q_res);
        let r: i64 = cks.decrypt_signed_radix(&r_res);
        assert_eq!(q, signed_div_under_modulus(clear_lhs, clear_rhs, modulus));
        assert_eq!(r, signed_rem_under_modulus(clear_lhs, clear_rhs, modulus));
    }
}

fn integer_signed_unchecked_scalar_div_rem_floor(param: impl Into<PBSParameters>) {
    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
@@ -0,0 +1,181 @@
use crate::integer::ciphertext::SignedRadixCiphertext;
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
use crate::integer::server_key::radix_parallel::tests_signed::{
    random_non_zero_signed_value_under_modulus, random_signed_value_under_modulus,
    signed_div_under_modulus, signed_rem_under_modulus, NB_CTXT,
};
use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
use crate::integer::tests::create_parametrized_test;
use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey};
#[cfg(tarpaulin)]
use crate::shortint::parameters::coverage_parameters::*;
use crate::shortint::parameters::*;
use itertools::iproduct;
use rand::prelude::*;
use std::sync::Arc;

create_parametrized_test!(integer_signed_unchecked_scalar_div_rem);

fn integer_signed_unchecked_scalar_div_rem<P>(param: P)
where
    P: Into<PBSParameters>,
{
    let executor =
        CpuFunctionExecutor::new(&ServerKey::unchecked_signed_scalar_div_rem_parallelized);
    signed_unchecked_scalar_div_rem_test(param, executor);
}

pub(crate) fn signed_unchecked_scalar_div_rem_test<P, T>(param: P, mut executor: T)
where
    P: Into<PBSParameters>,
    T: for<'a> FunctionExecutor<
            (&'a SignedRadixCiphertext, i64),
            (SignedRadixCiphertext, SignedRadixCiphertext),
        > + std::panic::UnwindSafe,
{
    let param = param.into();
    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

    let mut rng = rand::thread_rng();
    let cks = RadixClientKey::from((cks, NB_CTXT));
    let sks = Arc::new(sks);

    let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;

    executor.setup(&cks, sks.clone());

    // check when scalar is out of ciphertext MIN..=MAX
    for d in [
        rng.gen_range(i64::MIN..-modulus),
        rng.gen_range(modulus..=i64::MAX),
    ] {
        for numerator in [rng.gen_range(-modulus..=0), rng.gen_range(0..modulus)] {
            let ctxt_0 = cks.encrypt_signed(numerator);

            let (q_res, r_res) = executor.execute((&ctxt_0, d));
            let q: i64 = cks.decrypt_signed(&q_res);
            let r: i64 = cks.decrypt_signed(&r_res);
            assert_eq!(q, signed_div_under_modulus(numerator, d, modulus));
            assert_eq!(r, signed_rem_under_modulus(numerator, d, modulus));
        }
    }

    // The algorithm has a special case for when divisor is 1 or -1
    for d in [1i64, -1i64] {
        let clear_0 = rng.gen::<i64>() % modulus;

        let ctxt_0 = cks.encrypt_signed(clear_0);

        let (q_res, r_res) = executor.execute((&ctxt_0, d));
        let q: i64 = cks.decrypt_signed(&q_res);
        let r: i64 = cks.decrypt_signed(&r_res);
        assert_eq!(q, signed_div_under_modulus(clear_0, d, modulus));
        assert_eq!(r, signed_rem_under_modulus(clear_0, d, modulus));
    }

    // 3 / -3 takes the second branch in the if else if series
    for d in [3, -3] {
        {
            let neg_clear_0 = rng.gen_range(-modulus..=0);
            let ctxt_0 = cks.encrypt_signed(neg_clear_0);
            println!("{neg_clear_0} / {d}");
            let (q_res, r_res) = executor.execute((&ctxt_0, d));
            let q: i64 = cks.decrypt_signed(&q_res);
            let r: i64 = cks.decrypt_signed(&r_res);
            assert_eq!(q, signed_div_under_modulus(neg_clear_0, d, modulus));
            assert_eq!(r, signed_rem_under_modulus(neg_clear_0, d, modulus));
        }

        {
            let pos_clear_0 = rng.gen_range(0..modulus);
            let ctxt_0 = cks.encrypt_signed(pos_clear_0);
            println!("{pos_clear_0} / {d}");
            let (q_res, r_res) = executor.execute((&ctxt_0, d));
            let q: i64 = cks.decrypt_signed(&q_res);
            let r: i64 = cks.decrypt_signed(&r_res);
            assert_eq!(q, signed_div_under_modulus(pos_clear_0, d, modulus));
            assert_eq!(r, signed_rem_under_modulus(pos_clear_0, d, modulus));
        }
    }

    // Param 1_1 cannot do this, with our NB_CTXT
    if modulus >= 43 {
        // For param_2_2 this will take the third branch in the if else if series
        for d in [-89, 89] {
            {
                let neg_clear_0 = rng.gen_range(-modulus..=0);
                let ctxt_0 = cks.encrypt_signed(neg_clear_0);
                let (q_res, r_res) = executor.execute((&ctxt_0, d));
                let q: i64 = cks.decrypt_signed(&q_res);
                let r: i64 = cks.decrypt_signed(&r_res);
                assert_eq!(q, signed_div_under_modulus(neg_clear_0, d, modulus));
                assert_eq!(r, signed_rem_under_modulus(neg_clear_0, d, modulus));
            }

            {
                let pos_clear_0 = rng.gen_range(0..modulus);
                let ctxt_0 = cks.encrypt_signed(pos_clear_0);
                println!("{pos_clear_0} / {d}");
                let (q_res, r_res) = executor.execute((&ctxt_0, d));
                let q: i64 = cks.decrypt_signed(&q_res);
                let r: i64 = cks.decrypt_signed(&r_res);
                assert_eq!(q, signed_div_under_modulus(pos_clear_0, d, modulus));
                assert_eq!(r, signed_rem_under_modulus(pos_clear_0, d, modulus));
            }
        }

        // For param_2_2 this will take the first branch
        for (clear_0, clear_1) in [(43, 8), (43, -8), (-43, 8), (-43, -8)] {
            let ctxt_0 = cks.encrypt_signed(clear_0);

            let (q_res, r_res) = executor.execute((&ctxt_0, clear_1));
            let q: i64 = cks.decrypt_signed(&q_res);
            let r: i64 = cks.decrypt_signed(&r_res);
            assert_eq!(q, signed_div_under_modulus(clear_0, clear_1, modulus));
            assert_eq!(r, signed_rem_under_modulus(clear_0, clear_1, modulus));
        }
    }

    for d in [-modulus, modulus - 1] {
        {
            let neg_clear_0 = rng.gen_range(-modulus..=0);
            let ctxt_0 = cks.encrypt_signed(neg_clear_0);
            let (q_res, r_res) = executor.execute((&ctxt_0, d));
            let q: i64 = cks.decrypt_signed(&q_res);
            let r: i64 = cks.decrypt_signed(&r_res);
            assert_eq!(q, signed_div_under_modulus(neg_clear_0, d, modulus));
            assert_eq!(r, signed_rem_under_modulus(neg_clear_0, d, modulus));
        }

        {
            let pos_clear_0 = rng.gen_range(0..modulus);
            let ctxt_0 = cks.encrypt_signed(pos_clear_0);
            let (q_res, r_res) = executor.execute((&ctxt_0, d));
            let q: i64 = cks.decrypt_signed(&q_res);
            let r: i64 = cks.decrypt_signed(&r_res);
            assert_eq!(q, signed_div_under_modulus(pos_clear_0, d, modulus));
            assert_eq!(r, signed_rem_under_modulus(pos_clear_0, d, modulus));
        }
    }

    let lhs_values = random_signed_value_under_modulus::<6>(&mut rng, modulus);
    let rhs_values = random_non_zero_signed_value_under_modulus::<6>(&mut rng, modulus);

    for (clear_lhs, clear_rhs) in iproduct!(lhs_values, rhs_values) {
        let ctxt_0 = cks.encrypt_signed(clear_lhs);

        let (q_res, r_res) = executor.execute((&ctxt_0, clear_rhs));
        let q: i64 = cks.decrypt_signed(&q_res);
        let r: i64 = cks.decrypt_signed(&r_res);
        assert_eq!(q, signed_div_under_modulus(clear_lhs, clear_rhs, modulus));
        assert_eq!(r, signed_rem_under_modulus(clear_lhs, clear_rhs, modulus));
    }

    // Do this test last, so we can move the executor into the closure
    let result = std::panic::catch_unwind(move || {
        let numerator = sks.create_trivial_radix(1, NB_CTXT);
        executor.execute((&numerator, 0i64));
    });
    assert!(result.is_err(), "division by zero should panic");
}