Compare commits

...

8 Commits

Author SHA1 Message Date
Agnes Leroy
a01814e991 chore(gpu): change multi-gpu tests to run on rtx so it's cheaper 2024-09-12 14:03:42 +02:00
Agnes Leroy
b5d50cec5a chore(gpu): set test threads to 1 when BIG_INSTANCE is false to get a better view of failures in the ci 2024-09-12 13:24:07 +02:00
Agnes Leroy
c619eb479e chore(gpu): add comment, remove unnecessary sync 2024-09-12 13:24:07 +02:00
Agnes Leroy
4e8bdc4380 chore(gpu): add scalar div and signed scalar div to hl api 2024-09-12 10:49:41 +02:00
Agnes Leroy
1deaaf5249 feat(gpu): signed scalar div 2024-09-12 10:49:41 +02:00
Agnes Leroy
abd2fe1f4e chore(gpu): return if chunk_size is 0 2024-09-12 10:49:41 +02:00
Agnes Leroy
47d671b043 fix(gpu): return early in sum_ct if num radix is 2, pass different pointers to smart copy 2024-09-12 10:40:13 +02:00
Agnes Leroy
f700016776 chore(gpu): fix partial sum ct with 0 or 1 inputs in the vec
Also refactor the interface for Hillis & Steele prefix sum
2024-09-12 09:22:42 +02:00
28 changed files with 720 additions and 254 deletions

View File

@@ -146,8 +146,8 @@ jobs:
- name: Run core crypto and internal CUDA backend tests
run: |
BIG_TESTS_INSTANCE=TRUE make test_core_crypto_gpu
BIG_TESTS_INSTANCE=TRUE make test_integer_compression_gpu
BIG_TESTS_INSTANCE=FALSE make test_core_crypto_gpu
BIG_TESTS_INSTANCE=FALSE make test_integer_compression_gpu
BIG_TESTS_INSTANCE=TRUE make test_cuda_backend
- name: Run user docs tests

View File

@@ -386,8 +386,8 @@ void scratch_cuda_integer_compute_prefix_sum_hillis_steele_64(
void cuda_integer_compute_prefix_sum_hillis_steele_64(
void **streams, uint32_t *gpu_indexes, uint32_t gpu_count,
void *output_radix_lwe, void *input_radix_lwe, int8_t *mem_ptr, void **ksks,
void **bsks, uint32_t num_blocks, uint32_t shift);
void *output_radix_lwe, void *generates_or_propagates, int8_t *mem_ptr,
void **ksks, void **bsks, uint32_t num_blocks, uint32_t shift);
void cleanup_cuda_integer_compute_prefix_sum_hillis_steele_64(
void **streams, uint32_t *gpu_indexes, uint32_t gpu_count,
@@ -1357,6 +1357,7 @@ template <typename Torus> struct int_overflowing_sub_memory {
template <typename Torus> struct int_sum_ciphertexts_vec_memory {
Torus *new_blocks;
Torus *new_blocks_copy;
Torus *old_blocks;
Torus *small_lwe_vector;
int_radix_params params;
@@ -1384,6 +1385,9 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
new_blocks = (Torus *)cuda_malloc_async(
max_pbs_count * (params.big_lwe_dimension + 1) * sizeof(Torus),
streams[0], gpu_indexes[0]);
new_blocks_copy = (Torus *)cuda_malloc_async(
max_pbs_count * (params.big_lwe_dimension + 1) * sizeof(Torus),
streams[0], gpu_indexes[0]);
old_blocks = (Torus *)cuda_malloc_async(
max_pbs_count * (params.big_lwe_dimension + 1) * sizeof(Torus),
streams[0], gpu_indexes[0]);
@@ -1415,6 +1419,9 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
this->new_blocks = new_blocks;
this->old_blocks = old_blocks;
this->small_lwe_vector = small_lwe_vector;
new_blocks_copy = (Torus *)cuda_malloc_async(
max_pbs_count * (params.big_lwe_dimension + 1) * sizeof(Torus),
streams[0], gpu_indexes[0]);
d_smart_copy_in = (int32_t *)cuda_malloc_async(
max_pbs_count * sizeof(int32_t), streams[0], gpu_indexes[0]);
@@ -1433,8 +1440,8 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
cuda_drop_async(small_lwe_vector, streams[0], gpu_indexes[0]);
}
cuda_drop_async(new_blocks_copy, streams[0], gpu_indexes[0]);
scp_mem->release(streams, gpu_indexes, gpu_count);
delete scp_mem;
}
};

View File

@@ -81,14 +81,6 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector_64(
void cleanup_cuda_programmable_bootstrap(void *stream, uint32_t gpu_index,
int8_t **pbs_buffer);
uint64_t get_buffer_size_programmable_bootstrap_amortized_64(
uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t input_lwe_ciphertext_count);
uint64_t get_buffer_size_programmable_bootstrap_64(
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
uint32_t input_lwe_ciphertext_count);
}
template <typename Torus>

View File

@@ -195,15 +195,15 @@ void scratch_cuda_integer_compute_prefix_sum_hillis_steele_64(
void cuda_integer_compute_prefix_sum_hillis_steele_64(
void **streams, uint32_t *gpu_indexes, uint32_t gpu_count,
void *output_radix_lwe, void *input_radix_lwe, int8_t *mem_ptr, void **ksks,
void **bsks, uint32_t num_blocks, uint32_t shift) {
void *output_radix_lwe, void *generates_or_propagates, int8_t *mem_ptr,
void **ksks, void **bsks, uint32_t num_blocks, uint32_t shift) {
int_radix_params params = ((int_radix_lut<uint64_t> *)mem_ptr)->params;
host_compute_prefix_sum_hillis_steele<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(output_radix_lwe),
static_cast<uint64_t *>(input_radix_lwe), params,
static_cast<uint64_t *>(generates_or_propagates), params,
(int_radix_lut<uint64_t> *)mem_ptr, bsks, (uint64_t **)(ksks),
num_blocks);
}
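
For reference, the routine above performs a Hillis & Steele inclusive prefix scan over the `generates_or_propagates` blocks: log2(n) passes, where pass k combines each element with the element 2^k positions to its left. A minimal clear-text sketch of that data movement (plain `u64` addition stands in for the homomorphic combine done through the LUT; the function name is illustrative, not part of the CUDA API):

```rust
/// Hillis & Steele inclusive prefix scan on plain values.
/// The GPU version applies a homomorphic combine (via a LUT) instead of `+`,
/// but the data movement is the same: log2(n) passes, each reading the
/// previous pass's output at `i - stride`.
fn hillis_steele_scan(values: &mut [u64]) {
    let n = values.len();
    let mut stride = 1;
    while stride < n {
        // Read from a snapshot so a pass never consumes its own writes.
        let snapshot = values.to_vec();
        for i in stride..n {
            values[i] = snapshot[i].wrapping_add(snapshot[i - stride]);
        }
        stride *= 2;
    }
}

fn main() {
    let mut v = vec![1u64, 2, 3, 4, 5];
    hillis_steele_scan(&mut v);
    assert_eq!(v, vec![1, 3, 6, 10, 15]);
    println!("{v:?}");
}
```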

View File

@@ -241,7 +241,8 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(radix_lwe_out),
static_cast<uint64_t *>(radix_lwe_vec), terms_degree, bsks,
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec);
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec,
nullptr);
break;
case 1024:
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
@@ -249,7 +250,8 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(radix_lwe_out),
static_cast<uint64_t *>(radix_lwe_vec), terms_degree, bsks,
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec);
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec,
nullptr);
break;
case 2048:
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
@@ -257,7 +259,8 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(radix_lwe_out),
static_cast<uint64_t *>(radix_lwe_vec), terms_degree, bsks,
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec);
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec,
nullptr);
break;
case 4096:
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
@@ -265,7 +268,8 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(radix_lwe_out),
static_cast<uint64_t *>(radix_lwe_vec), terms_degree, bsks,
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec);
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec,
nullptr);
break;
case 8192:
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
@@ -273,7 +277,8 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(radix_lwe_out),
static_cast<uint64_t *>(radix_lwe_vec), terms_degree, bsks,
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec);
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec,
nullptr);
break;
case 16384:
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
@@ -281,7 +286,8 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(radix_lwe_out),
static_cast<uint64_t *>(radix_lwe_vec), terms_degree, bsks,
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec);
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec,
nullptr);
break;
default:
PANIC("Cuda error (integer multiplication): unsupported polynomial size. "

View File

@@ -186,9 +186,10 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
Torus *radix_lwe_out, Torus *terms, int *terms_degree, void **bsks,
uint64_t **ksks, int_sum_ciphertexts_vec_memory<uint64_t> *mem_ptr,
uint32_t num_blocks_in_radix, uint32_t num_radix_in_vec,
int_radix_lut<Torus> *reused_lut = nullptr) {
int_radix_lut<Torus> *reused_lut) {
auto new_blocks = mem_ptr->new_blocks;
auto new_blocks_copy = mem_ptr->new_blocks_copy;
auto old_blocks = mem_ptr->old_blocks;
auto small_lwe_vector = mem_ptr->small_lwe_vector;
@@ -205,12 +206,27 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
auto small_lwe_dimension = mem_ptr->params.small_lwe_dimension;
auto small_lwe_size = small_lwe_dimension + 1;
if (num_radix_in_vec == 0)
return;
if (num_radix_in_vec == 1) {
cuda_memcpy_async_gpu_to_gpu(radix_lwe_out, terms,
num_blocks_in_radix * big_lwe_size *
sizeof(Torus),
streams[0], gpu_indexes[0]);
return;
}
if (old_blocks != terms) {
cuda_memcpy_async_gpu_to_gpu(old_blocks, terms,
num_blocks_in_radix * num_radix_in_vec *
big_lwe_size * sizeof(Torus),
streams[0], gpu_indexes[0]);
}
if (num_radix_in_vec == 2) {
host_addition<Torus>(streams[0], gpu_indexes[0], radix_lwe_out, old_blocks,
&old_blocks[num_blocks * big_lwe_size],
big_lwe_dimension, num_blocks);
return;
}
size_t r = num_radix_in_vec;
size_t total_modulus = message_modulus * carry_modulus;
@@ -287,7 +303,6 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
terms_degree, h_lwe_idx_in, h_lwe_idx_out, h_smart_copy_in,
h_smart_copy_out, ch_amount, r, num_blocks, chunk_size, message_max,
total_count, message_count, carry_count, sm_copy_count);
cuda_synchronize_stream(streams[0], gpu_indexes[0]);
auto lwe_indexes_in = luts_message_carry->lwe_indexes_in;
auto lwe_indexes_out = luts_message_carry->lwe_indexes_out;
luts_message_carry->set_lwe_indexes(streams[0], gpu_indexes[0],
@@ -302,8 +317,11 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
// inside d_smart_copy_in there are only -1 values
// it's fine to call smart_copy with same pointer
// as source and destination
cuda_memcpy_async_gpu_to_gpu(new_blocks_copy, new_blocks,
r * num_blocks * big_lwe_size * sizeof(Torus),
streams[0], gpu_indexes[0]);
smart_copy<Torus><<<sm_copy_count, 1024, 0, streams[0]>>>(
new_blocks, new_blocks, d_smart_copy_out, d_smart_copy_in,
new_blocks, new_blocks_copy, d_smart_copy_out, d_smart_copy_in,
big_lwe_size);
check_cuda_error(cudaGetLastError());
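
The copy into `new_blocks_copy` above exists because a gather whose source and destination alias can read slots that earlier writes already clobbered; with GPU thread scheduling the result would also be nondeterministic. A clear-text sketch of the hazard and of the fix (hypothetical helper names; the real kernel moves whole LWE blocks, not single words):

```rust
// In-place gather: dst[i] = dst[idx_in[i]] may read a slot that an earlier
// iteration already overwrote, so the result depends on traversal order.
fn gather_in_place(data: &mut [u64], idx_in: &[usize]) {
    for (i, &src) in idx_in.iter().enumerate() {
        data[i] = data[src];
    }
}

// Gathering from an untouched copy removes the hazard entirely.
fn gather_from_copy(data: &mut [u64], idx_in: &[usize]) {
    let copy = data.to_vec();
    for (i, &src) in idx_in.iter().enumerate() {
        data[i] = copy[src];
    }
}

fn main() {
    let idx = [1usize, 0];
    let mut a = [10u64, 20];
    gather_in_place(&mut a, &idx); // [20, 20]: slot 0 was clobbered first
    let mut b = [10u64, 20];
    gather_from_copy(&mut b, &idx); // [20, 10]: the intended swap
    assert_eq!(a, [20, 20]);
    assert_eq!(b, [20, 10]);
}
```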

View File

@@ -91,7 +91,6 @@ __host__ void host_integer_scalar_mul_radix(
j++;
}
}
cuda_synchronize_stream(streams[0], gpu_indexes[0]);
cuda_drop_async(preshifted_buffer, streams[0], gpu_indexes[0]);
mem->logical_scalar_shift_buffer->release(streams, gpu_indexes, gpu_count);
@@ -109,7 +108,7 @@ __host__ void host_integer_scalar_mul_radix(
host_integer_partial_sum_ciphertexts_vec_kb<T, params>(
streams, gpu_indexes, gpu_count, lwe_array, all_shifted_buffer,
terms_degree, bsks, ksks, mem->sum_ciphertexts_vec_mem,
num_radix_blocks, j);
num_radix_blocks, j, nullptr);
auto scp_mem_ptr = mem->sum_ciphertexts_vec_mem->scp_mem;
host_propagate_single_carry<T>(streams, gpu_indexes, gpu_count, lwe_array,

View File

@@ -1,15 +1,5 @@
#include "programmable_bootstrap_amortized.cuh"
/*
* Returns the buffer size for 64 bits executions
*/
uint64_t get_buffer_size_programmable_bootstrap_amortized_64(
uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t input_lwe_ciphertext_count) {
return get_buffer_size_programmable_bootstrap_amortized<uint64_t>(
glwe_dimension, polynomial_size, input_lwe_ciphertext_count);
}
/*
* This scratch function allocates the necessary amount of data on the GPU for
* the amortized PBS on 32 bits inputs, into `buffer`. It also

View File

@@ -256,7 +256,7 @@ __host__ void execute_cg_external_product_loop(
pbs_buffer<Torus, MULTI_BIT> *buffer, uint32_t num_samples,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t grouping_factor, uint32_t base_log, uint32_t level_count,
uint32_t lwe_chunk_size, int lwe_offset) {
uint32_t lwe_chunk_size, uint32_t lwe_offset) {
uint64_t full_dm =
get_buffer_size_full_sm_cg_multibit_programmable_bootstrap<Torus>(
@@ -275,6 +275,8 @@ __host__ void execute_cg_external_product_loop(
uint32_t chunk_size =
std::min(lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
if (chunk_size == 0)
return;
auto d_mem = buffer->d_mem_acc_cg;
auto keybundle_fft = buffer->keybundle_fft;
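
The `chunk_size == 0` guard covers the final pass over the bootstrapping key: the caller advances `lwe_offset` by `lwe_chunk_size` per pass, so the remaining count `lwe_dimension / grouping_factor - lwe_offset` can reach zero, in which case no kernel should be launched. A small clear-text sketch of that arithmetic (illustrative sizes only):

```rust
fn main() {
    // Hypothetical sizes: 16 keybundles processed 4 at a time.
    let keybundle_count = 16u32; // lwe_dimension / grouping_factor
    let lwe_chunk_size = 4u32;

    for lwe_offset in (0..=keybundle_count).step_by(lwe_chunk_size as usize) {
        let chunk_size = lwe_chunk_size.min(keybundle_count - lwe_offset);
        if chunk_size == 0 {
            println!("offset {lwe_offset}: empty chunk, skip");
            continue; // mirrors the early `return` added above
        }
        println!("offset {lwe_offset}: process {chunk_size} keybundles");
    }
}
```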

View File

@@ -182,25 +182,6 @@ void cuda_programmable_bootstrap_tbc_lwe_ciphertext_vector(
}
#endif
/*
* Returns the buffer size for 64 bits executions
*/
uint64_t get_buffer_size_programmable_bootstrap_64(
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
uint32_t input_lwe_ciphertext_count) {
if (has_support_to_cuda_programmable_bootstrap_cg<uint64_t>(
glwe_dimension, polynomial_size, level_count,
input_lwe_ciphertext_count))
return get_buffer_size_programmable_bootstrap_cg<uint64_t>(
glwe_dimension, polynomial_size, level_count,
input_lwe_ciphertext_count);
else
return get_buffer_size_programmable_bootstrap_cg<uint64_t>(
glwe_dimension, polynomial_size, level_count,
input_lwe_ciphertext_count);
}
template <typename Torus>
void scratch_cuda_programmable_bootstrap_cg(
void *stream, uint32_t gpu_index, pbs_buffer<Torus, CLASSICAL> **pbs_buffer,

View File

@@ -465,7 +465,7 @@ __host__ void execute_compute_keybundle(
pbs_buffer<Torus, MULTI_BIT> *buffer, uint32_t num_samples,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t grouping_factor, uint32_t base_log, uint32_t level_count,
uint32_t lwe_chunk_size, int lwe_offset) {
uint32_t lwe_chunk_size, uint32_t lwe_offset) {
uint32_t chunk_size =
std::min(lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
@@ -506,14 +506,12 @@ __host__ void execute_compute_keybundle(
}
template <typename Torus, class params>
__host__ void execute_step_one(cudaStream_t stream, uint32_t gpu_index,
Torus *lut_vector, Torus *lut_vector_indexes,
Torus *lwe_array_in, Torus *lwe_input_indexes,
pbs_buffer<Torus, MULTI_BIT> *buffer,
uint32_t num_samples, uint32_t lwe_dimension,
uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t base_log,
uint32_t level_count, int j, int lwe_offset) {
__host__ void execute_step_one(
cudaStream_t stream, uint32_t gpu_index, Torus *lut_vector,
Torus *lut_vector_indexes, Torus *lwe_array_in, Torus *lwe_input_indexes,
pbs_buffer<Torus, MULTI_BIT> *buffer, uint32_t num_samples,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count, uint32_t j, uint32_t lwe_offset) {
uint64_t full_sm_accumulate_step_one =
get_buffer_size_full_sm_multibit_programmable_bootstrap_step_one<Torus>(
@@ -562,14 +560,12 @@ __host__ void execute_step_one(cudaStream_t stream, uint32_t gpu_index,
}
template <typename Torus, class params>
__host__ void execute_step_two(cudaStream_t stream, uint32_t gpu_index,
Torus *lwe_array_out, Torus *lwe_output_indexes,
pbs_buffer<Torus, MULTI_BIT> *buffer,
uint32_t num_samples, uint32_t lwe_dimension,
uint32_t glwe_dimension,
uint32_t polynomial_size,
int32_t grouping_factor, uint32_t level_count,
int j, int lwe_offset, uint32_t lwe_chunk_size) {
__host__ void execute_step_two(
cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out,
Torus *lwe_output_indexes, pbs_buffer<Torus, MULTI_BIT> *buffer,
uint32_t num_samples, uint32_t lwe_dimension, uint32_t glwe_dimension,
uint32_t polynomial_size, int32_t grouping_factor, uint32_t level_count,
uint32_t j, uint32_t lwe_offset, uint32_t lwe_chunk_size) {
uint64_t full_sm_accumulate_step_two =
get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two<Torus>(
@@ -627,7 +623,7 @@ __host__ void host_multi_bit_programmable_bootstrap(
// Accumulate
uint32_t chunk_size = std::min(
lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
for (int j = 0; j < chunk_size; j++) {
for (uint32_t j = 0; j < chunk_size; j++) {
execute_step_one<Torus, params>(
stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
lwe_input_indexes, buffer, num_samples, lwe_dimension, glwe_dimension,

View File

@@ -267,7 +267,7 @@ __host__ void execute_tbc_external_product_loop(
pbs_buffer<Torus, MULTI_BIT> *buffer, uint32_t num_samples,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t grouping_factor, uint32_t base_log, uint32_t level_count,
uint32_t lwe_chunk_size, int lwe_offset) {
uint32_t lwe_chunk_size, uint32_t lwe_offset) {
auto supports_dsm =
supports_distributed_shared_memory_on_multibit_programmable_bootstrap<
@@ -294,6 +294,8 @@ __host__ void execute_tbc_external_product_loop(
uint32_t chunk_size =
std::min(lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
if (chunk_size == 0)
return;
auto d_mem = buffer->d_mem_acc_tbc;
auto keybundle_fft = buffer->keybundle_fft;

View File

@@ -222,6 +222,8 @@ __device__ void sample_extract_mask(Torus *lwe_array_out, Torus *glwe,
Torus result[params::opt];
#pragma unroll
for (int i = 0; i < params::opt; i++) {
// params::degree - tid - 1 can't be negative, tid goes from 0 to
// params::degree - 1
auto x = glwe_slice[params::degree - tid - 1];
result[i] = SEL(-x, x, tid >= params::degree - nth);
tid = tid + params::degree / params::opt;

View File

@@ -747,7 +747,7 @@ extern "C" {
gpu_indexes: *const u32,
gpu_count: u32,
radix_lwe_out: *mut c_void,
radix_lwe_vec: *const c_void,
radix_lwe_vec: *mut c_void,
num_radix_in_vec: u32,
mem_ptr: *mut i8,
bsks: *const *mut c_void,
@@ -959,7 +959,7 @@ extern "C" {
gpu_indexes: *const u32,
gpu_count: u32,
output_radix_lwe: *mut c_void,
input_radix_lwe: *const c_void,
generates_or_propagates: *mut c_void,
mem_ptr: *mut i8,
ksks: *const *mut c_void,
bsks: *const *mut c_void,

View File

@@ -58,7 +58,7 @@ flavor_name = "n3-A100x8-NVLink"
[backend.hyperstack.multi-gpu-test]
environment_name = "canada"
image_name = "Ubuntu Server 22.04 LTS R535 CUDA 12.2"
flavor_name = "n3-A100x4"
flavor_name = "n3-RTX-A6000x4"
[command.signed_integer_full_bench]
workflow = "signed_integer_full_benchmark.yml"

View File

@@ -133,8 +133,8 @@ if [[ "${backend}" == "gpu" ]]; then
test_threads=8
doctest_threads=8
else
test_threads=3
doctest_threads=3
test_threads=1
doctest_threads=1
fi
fi

View File

@@ -1813,6 +1813,12 @@ mod cuda {
rng_func: default_signed_scalar
);
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: unchecked_signed_scalar_div_rem,
display_name: div_rem,
rng_func: div_scalar
);
//===========================================
// Default
//===========================================
@@ -2035,6 +2041,12 @@ mod cuda {
rng_func: default_signed_scalar
);
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: signed_scalar_div_rem,
display_name: div_rem,
rng_func: div_scalar
);
criterion_group!(
unchecked_cuda_ops,
cuda_unchecked_add,
@@ -2081,6 +2093,7 @@ mod cuda {
cuda_unchecked_scalar_le,
cuda_unchecked_scalar_min,
cuda_unchecked_scalar_max,
cuda_unchecked_signed_scalar_div_rem,
);
criterion_group!(
@@ -2146,6 +2159,7 @@ mod cuda {
cuda_scalar_max,
cuda_signed_overflowing_scalar_add,
cuda_signed_overflowing_scalar_sub,
cuda_signed_scalar_div_rem,
);
fn cuda_bench_server_key_signed_cast_function<F>(

View File

@@ -365,7 +365,6 @@ where
// DivRem is a bit special as it returns a tuple of quotient and remainder
macro_rules! generic_integer_impl_scalar_div_rem {
(
key_method: $key_method:ident,
// A 'list' of tuple, where the first element is the concrete Fhe type
// e.g (FheUint8 and the rest is scalar types (u8, u16, etc)
fhe_and_scalar_type: $(
@@ -393,15 +392,24 @@ macro_rules! generic_integer_impl_scalar_div_rem {
InternalServerKey::Cpu(cpu_key) => {
let (q, r) = cpu_key
.pbs_key()
.$key_method(&*self.ciphertext.on_cpu(), rhs);
.signed_scalar_div_rem_parallelized(&*self.ciphertext.on_cpu(), rhs);
(
<$concrete_type>::new(q, cpu_key.tag.clone()),
<$concrete_type>::new(r, cpu_key.tag.clone())
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices does not support div rem yet")
InternalServerKey::Cuda(cuda_key) => {
let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| {
cuda_key.key.signed_scalar_div_rem(
&*self.ciphertext.on_gpu(), rhs, streams
)
});
let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r));
(
<$concrete_type>::new(q, cuda_key.tag.clone()),
<$concrete_type>::new(r, cuda_key.tag.clone())
)
}
})
}
@@ -410,8 +418,8 @@ macro_rules! generic_integer_impl_scalar_div_rem {
)* // Closing first repeating pattern
};
}
generic_integer_impl_scalar_div_rem!(
key_method: signed_scalar_div_rem_parallelized,
fhe_and_scalar_type:
(super::FheInt2, i8),
(super::FheInt4, i8),

View File

@@ -446,7 +446,6 @@ where
// DivRem is a bit special as it returns a tuple of quotient and remainder
macro_rules! generic_integer_impl_scalar_div_rem {
(
key_method: $key_method:ident,
// A 'list' of tuple, where the first element is the concrete Fhe type
// e.g (FheUint8 and the rest is scalar types (u8, u16, etc)
fhe_and_scalar_type: $(
@@ -473,15 +472,24 @@ macro_rules! generic_integer_impl_scalar_div_rem {
global_state::with_internal_keys(|key| {
match key {
InternalServerKey::Cpu(cpu_key) => {
let (q, r) = cpu_key.pbs_key().$key_method(&*self.ciphertext.on_cpu(), rhs);
let (q, r) = cpu_key.pbs_key().scalar_div_rem_parallelized(&*self.ciphertext.on_cpu(), rhs);
(
<$concrete_type>::new(q, cpu_key.tag.clone()),
<$concrete_type>::new(r, cpu_key.tag.clone())
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support div_rem yet");
InternalServerKey::Cuda(cuda_key) => {
let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| {
cuda_key.key.scalar_div_rem(
&*self.ciphertext.on_gpu(), rhs, streams
)
});
let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r));
(
<$concrete_type>::new(q, cuda_key.tag.clone()),
<$concrete_type>::new(r, cuda_key.tag.clone())
)
}
}
})
@@ -492,7 +500,6 @@ macro_rules! generic_integer_impl_scalar_div_rem {
};
}
generic_integer_impl_scalar_div_rem!(
key_method: scalar_div_rem_parallelized,
fhe_and_scalar_type:
(super::FheUint2, u8),
(super::FheUint4, u8),

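With the two macro fixes above, scalar `div_rem` on the high-level types calls `scalar_div_rem_parallelized` on CPU and the new CUDA path on GPU. A hedged usage sketch, assuming the usual `tfhe` high-level setup and that these macros back the `DivRem` trait re-exported in `tfhe::prelude` (treat the exact trait name and key-generation flow as assumptions; a GPU run additionally needs a CUDA server key passed to `set_server_key`):

```rust
use tfhe::prelude::*;
use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint8};

fn main() {
    let config = ConfigBuilder::default().build();
    let (client_key, server_key) = generate_keys(config);
    set_server_key(server_key);

    let a = FheUint8::encrypt(37u8, &client_key);
    // Scalar division and remainder in one call; dispatches to the CUDA
    // branch added above when a GPU server key is set.
    let (q, r) = (&a).div_rem(5u8);
    let q: u8 = q.decrypt(&client_key);
    let r: u8 = r.decrypt(&client_key);
    assert_eq!((q, r), (7, 2));
}
```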
View File

@@ -2578,7 +2578,7 @@ pub unsafe fn unchecked_signed_overflowing_add_or_sub_radix_kb_assign_async<
pub unsafe fn compute_prefix_sum_hillis_steele_async<T: UnsignedInteger, B: Numeric>(
streams: &CudaStreams,
radix_lwe_output: &mut CudaSliceMut<T>,
radix_lwe_input: &CudaSlice<T>,
generates_or_propagates: &mut CudaSliceMut<T>,
input_lut: &[T],
bootstrapping_key: &CudaVec<B>,
keyswitch_key: &CudaVec<T>,
@@ -2598,7 +2598,7 @@ pub unsafe fn compute_prefix_sum_hillis_steele_async<T: UnsignedInteger, B: Nume
) {
assert_eq!(
streams.gpu_indexes[0],
radix_lwe_input.gpu_index(0),
generates_or_propagates.gpu_index(0),
"GPU error: all data should reside on the same GPU."
);
assert_eq!(
@@ -2643,7 +2643,7 @@ pub unsafe fn compute_prefix_sum_hillis_steele_async<T: UnsignedInteger, B: Nume
streams.gpu_indexes.as_ptr(),
streams.len() as u32,
radix_lwe_output.as_mut_c_ptr(0),
radix_lwe_input.as_c_ptr(0),
generates_or_propagates.as_mut_c_ptr(0),
mem_ptr,
keyswitch_key.ptr.as_ptr(),
bootstrapping_key.ptr.as_ptr(),

View File

@@ -40,8 +40,7 @@ impl CudaServerKey {
let lwe_size = ct.as_ref().d_blocks.0.lwe_dimension.to_lwe_size().0;
// Allocate the necessary amount of memory
let mut output_radix =
CudaVec::new(num_ct_blocks * lwe_size, streams, streams.gpu_indexes[0]);
let mut tmp_radix = CudaVec::new(num_ct_blocks * lwe_size, streams, streams.gpu_indexes[0]);
let lut = match direction {
Direction::Trailing => self.generate_lookup_table(|x| {
@@ -70,12 +69,12 @@ impl CudaServerKey {
}),
};
output_radix.copy_from_gpu_async(
tmp_radix.copy_from_gpu_async(
&ct.as_ref().d_blocks.0.d_vec,
streams,
streams.gpu_indexes[0],
);
let mut output_slice = output_radix
let mut output_slice = tmp_radix
.as_mut_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0])
.unwrap();
@@ -167,27 +166,27 @@ impl CudaServerKey {
},
);
let mut cts = CudaLweCiphertextList::new(
let mut output_cts = CudaLweCiphertextList::new(
ct.as_ref().d_blocks.lwe_dimension(),
LweCiphertextCount(num_ct_blocks * ct.as_ref().d_blocks.lwe_ciphertext_count().0),
ct.as_ref().d_blocks.ciphertext_modulus(),
streams,
);
let input_radix_slice = output_radix
.as_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0])
let mut generates_or_propagates = tmp_radix
.as_mut_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0])
.unwrap();
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
compute_prefix_sum_hillis_steele_async(
streams,
&mut cts
&mut output_cts
.0
.d_vec
.as_mut_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0])
.unwrap(),
&input_radix_slice,
&mut generates_or_propagates,
sum_lut.acc.acc.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
@@ -211,12 +210,12 @@ impl CudaServerKey {
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
compute_prefix_sum_hillis_steele_async(
streams,
&mut cts
&mut output_cts
.0
.d_vec
.as_mut_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0])
.unwrap(),
&input_radix_slice,
&mut generates_or_propagates,
sum_lut.acc.acc.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
@@ -238,7 +237,7 @@ impl CudaServerKey {
);
}
}
cts
output_cts
}
/// Counts how many consecutive bits there are

View File

@@ -1,9 +1,13 @@
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::Numeric;
use crate::integer::block_decomposition::DecomposableInto;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::ciphertext::{
CudaIntegerRadixCiphertext, CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext,
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::scalar_div_mod::choose_multiplier;
use crate::integer::server_key::radix_parallel::scalar_div_mod::{
choose_multiplier, SignedReciprocable,
};
use crate::integer::server_key::{MiniUnsignedInteger, Reciprocable, ScalarMultiplier};
use crate::prelude::{CastFrom, CastInto};
@@ -32,6 +36,21 @@ impl CudaServerKey {
result
}
fn signed_scalar_mul_high<Scalar>(
&self,
lhs: &CudaSignedRadixCiphertext,
rhs: Scalar,
streams: &CudaStreams,
) -> CudaSignedRadixCiphertext
where
Scalar: ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
{
let num_blocks = lhs.as_ref().d_blocks.lwe_ciphertext_count().0;
let mut result = self.extend_radix_with_sign_msb(lhs, num_blocks, streams);
self.scalar_mul_assign(&mut result, rhs, streams);
self.trim_radix_blocks_lsb(&result, num_blocks, streams)
}
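
`signed_scalar_mul_high` is the homomorphic MULSH used by the division algorithm: the operand is sign-extended by `num_blocks` extra blocks, multiplied by the scalar, and the low `num_blocks` blocks are trimmed away, leaving the high half of the double-width product. On clear 64-bit values the same operation is a double-width multiply followed by a shift (purely illustrative sketch):

```rust
/// Clear-text analogue of signed_scalar_mul_high for 64-bit values:
/// the upper 64 bits of the exact 128-bit signed product.
fn signed_mul_high(lhs: i64, rhs: i64) -> i64 {
    ((lhs as i128 * rhs as i128) >> 64) as i64
}

fn main() {
    assert_eq!(signed_mul_high(1 << 40, 1 << 30), 1 << 6); // 2^70 >> 64
    assert_eq!(signed_mul_high(-1, 1), -1); // high half of -1 is all ones
    println!("ok");
}
```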
/// Computes homomorphically a division between a ciphertext and a scalar.
///
/// This function computes the operation without checking if it exceeds the capacity of the
@@ -403,4 +422,262 @@ impl CudaServerKey {
self.unchecked_scalar_rem(numerator, divisor, streams)
}
pub fn unchecked_signed_scalar_div<Scalar>(
&self,
numerator: &CudaSignedRadixCiphertext,
divisor: Scalar,
streams: &CudaStreams,
) -> CudaSignedRadixCiphertext
where
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
{
assert_ne!(divisor, Scalar::ZERO, "attempt to divide by 0");
let numerator_bits = self.message_modulus.0.ilog2()
* numerator.ciphertext.d_blocks.lwe_ciphertext_count().0 as u32;
assert!(
Scalar::BITS >= numerator_bits as usize,
"The scalar divisor type must have a number of bits that is\
>= to the number of bits encrypted in the ciphertext"
);
// wrapping_abs returns Scalar::MIN when its input is Scalar::MIN (since in signed numbers
// Scalar::MIN's absolute value cannot be represented).
// However, casting Scalar::MIN to the unsigned type gives the correct abs value,
// provided Scalar and Scalar::Unsigned have the same number of bits
let absolute_divisor = Scalar::Unsigned::cast_from(divisor.wrapping_abs());
if absolute_divisor == Scalar::Unsigned::ONE {
// Strangely, the paper says: Issue q = d;
return if divisor < Scalar::ZERO {
// quotient = -quotient;
self.neg(numerator, streams)
} else {
numerator.duplicate(streams)
};
}
let chosen_multiplier =
choose_multiplier(absolute_divisor, numerator_bits - 1, numerator_bits);
if chosen_multiplier.l >= numerator_bits {
return self.create_trivial_zero_radix(
numerator.ciphertext.d_blocks.lwe_ciphertext_count().0,
streams,
);
}
let quotient;
if absolute_divisor == (Scalar::Unsigned::ONE << chosen_multiplier.l as usize) {
// Issue q = SRA(n + SRL(SRA(n, l - 1), N - l), l);
let l = chosen_multiplier.l;
// SRA(n, l - 1)
let mut tmp = self.unchecked_scalar_right_shift(numerator, l - 1, streams);
// SRL(SRA(n, l 1), N l)
unsafe {
self.unchecked_scalar_right_shift_logical_assign_async(
&mut tmp,
(numerator_bits - l) as usize,
streams,
);
}
streams.synchronize();
// n + SRL(SRA(n, l - 1), N - l)
self.add_assign(&mut tmp, numerator, streams);
// SRA(n + SRL(SRA(n, l - 1), N - l), l);
quotient = self.unchecked_scalar_right_shift(&tmp, l, streams);
} else if chosen_multiplier.multiplier
< (<Scalar::Unsigned as Reciprocable>::DoublePrecision::ONE << (numerator_bits - 1))
{
// in the condition above works (it makes more values take this branch,
// but results still seemed correct)
// multiplier is less than the max possible value of Scalar
// Issue q = SRA(MULSH(m, n), shpost) - XSIGN(n);
let (mut tmp, xsign) = rayon::join(
move || {
// MULSH(m, n)
let mut tmp = self.signed_scalar_mul_high(
numerator,
chosen_multiplier.multiplier,
streams,
);
// SRA(MULSH(m, n), shpost)
unsafe {
self.unchecked_scalar_right_shift_assign_async(
&mut tmp,
chosen_multiplier.shift_post,
streams,
);
}
streams.synchronize();
tmp
},
|| {
// XSIGN is: if x < 0 { -1 } else { 0 }
// It is equivalent to SRA(x, N - 1)
self.unchecked_scalar_right_shift(numerator, numerator_bits - 1, streams)
},
);
self.sub_assign(&mut tmp, &xsign, streams);
quotient = tmp;
} else {
// Issue q = SRA(n + MULSH(m - 2^N, n), shpost) - XSIGN(n);
// Note from the paper: m - 2^N is negative
let (mut tmp, xsign) = rayon::join(
move || {
// The subtraction may overflow.
// We then cast the result to a signed type.
// Overall, this will work fine due to two's complement representation
let cst = chosen_multiplier.multiplier
- (<Scalar::Unsigned as Reciprocable>::DoublePrecision::ONE
<< numerator_bits);
let cst = Scalar::DoublePrecision::cast_from(cst);
// MULSH(m - 2^N, n)
let mut tmp = self.signed_scalar_mul_high(numerator, cst, streams);
// n + MULSH(m - 2^N, n)
self.add_assign(&mut tmp, numerator, streams);
// SRA(n + MULSH(m - 2^N, n), shpost)
tmp = self.unchecked_scalar_right_shift(
&tmp,
chosen_multiplier.shift_post,
streams,
);
tmp
},
|| {
// XSIGN is: if x < 0 { -1 } else { 0 }
// It is equivalent to SRA(x, N - 1)
self.unchecked_scalar_right_shift(numerator, numerator_bits - 1, streams)
},
);
self.sub_assign(&mut tmp, &xsign, streams);
quotient = tmp;
}
if divisor < Scalar::ZERO {
self.neg(&quotient, streams)
} else {
quotient
}
}
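
The branches above follow the usual division-by-invariant-integer recipe (the same one the CPU `signed_scalar_div_parallelized` relies on): divisors of magnitude 1 are handled directly, powers of two reduce to shifts, and everything else goes through a multiplier precomputed by `choose_multiplier`, a high multiply, a post-shift and an XSIGN correction. The two shift identities the code relies on can be checked on clear `i64` values, with the homomorphic SRA mapping to `unchecked_scalar_right_shift` and SRL to the new logical-shift entry point (illustrative sketch, N = 64):

```rust
/// XSIGN(n) = SRA(n, N - 1): -1 if n < 0, else 0.
fn xsign(n: i64) -> i64 {
    n >> 63
}

/// SRA(n + SRL(SRA(n, l - 1), N - l), l): division by 2^l rounded toward zero.
fn div_pow2(n: i64, l: u32) -> i64 {
    let rounding = ((n >> (l - 1)) as u64 >> (64 - l)) as i64;
    (n + rounding) >> l
}

fn main() {
    for &n in &[-2049i64, -2048, -7, -1, 0, 1, 7, 2047, 2048] {
        assert_eq!(xsign(n), if n < 0 { -1 } else { 0 });
        for l in 1u32..12 {
            // Rust's `/` truncates toward zero, like the formula.
            assert_eq!(div_pow2(n, l), n / (1i64 << l));
        }
    }
    println!("shift identities hold");
}
```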
pub fn signed_scalar_div<Scalar>(
&self,
numerator: &CudaSignedRadixCiphertext,
divisor: Scalar,
streams: &CudaStreams,
) -> CudaSignedRadixCiphertext
where
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
{
let mut tmp_numerator;
let numerator = if numerator.block_carries_are_empty() {
numerator
} else {
unsafe {
tmp_numerator = numerator.duplicate_async(streams);
self.full_propagate_assign_async(&mut tmp_numerator, streams);
}
&tmp_numerator
};
self.unchecked_signed_scalar_div(numerator, divisor, streams)
}
pub fn unchecked_signed_scalar_div_rem<Scalar>(
&self,
numerator: &CudaSignedRadixCiphertext,
divisor: Scalar,
streams: &CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext)
where
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
{
let quotient = self.unchecked_signed_scalar_div(numerator, divisor, streams);
// remainder = numerator - (quotient * divisor)
let tmp = self.unchecked_scalar_mul(&quotient, divisor, streams);
let remainder = self.sub(numerator, &tmp, streams);
(quotient, remainder)
}
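
`unchecked_signed_scalar_div_rem` derives the remainder from the quotient with the standard truncated-division identity, so the remainder takes the sign of the numerator. A quick clear-text check (plain `i64`):

```rust
fn main() {
    for &(n, d) in &[(7i64, 3i64), (-7, 3), (7, -3), (-7, -3)] {
        let q = n / d; // Rust's `/` truncates toward zero, like the GPU path
        let r = n - q * d; // remainder = numerator - quotient * divisor
        assert_eq!(r, n % d);
        assert!(r == 0 || (r < 0) == (n < 0)); // remainder follows the numerator's sign
    }
    println!("remainder identity holds");
}
```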
pub fn signed_scalar_div_rem<Scalar>(
&self,
numerator: &CudaSignedRadixCiphertext,
divisor: Scalar,
streams: &CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext)
where
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
{
let mut tmp_numerator;
let numerator = if numerator.block_carries_are_empty() {
numerator
} else {
unsafe {
tmp_numerator = numerator.duplicate_async(streams);
self.full_propagate_assign_async(&mut tmp_numerator, streams);
}
&tmp_numerator
};
self.unchecked_signed_scalar_div_rem(numerator, divisor, streams)
}
pub fn unchecked_signed_scalar_rem<Scalar>(
&self,
numerator: &CudaSignedRadixCiphertext,
divisor: Scalar,
streams: &CudaStreams,
) -> CudaSignedRadixCiphertext
where
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
{
let (_, remainder) = self.unchecked_signed_scalar_div_rem(numerator, divisor, streams);
remainder
}
pub fn signed_scalar_rem<Scalar>(
&self,
numerator: &CudaSignedRadixCiphertext,
divisor: Scalar,
streams: &CudaStreams,
) -> CudaSignedRadixCiphertext
where
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
{
let mut tmp_numerator;
let numerator = if numerator.block_carries_are_empty() {
numerator
} else {
unsafe {
tmp_numerator = numerator.duplicate_async(streams);
self.full_propagate_assign_async(&mut tmp_numerator, streams);
}
&tmp_numerator
};
self.unchecked_signed_scalar_rem(numerator, divisor, streams)
}
}

View File

@@ -565,4 +565,76 @@ impl CudaServerKey {
};
stream.synchronize();
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_scalar_right_shift_logical_assign_async<Scalar, T>(
&self,
ct: &mut T,
shift: Scalar,
stream: &CudaStreams,
) where
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
T: CudaIntegerRadixCiphertext,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
unchecked_scalar_logical_right_shift_integer_radix_kb_assign_async(
stream,
&mut ct.as_mut().d_blocks.0.d_vec,
u32::cast_from(shift),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
unchecked_scalar_logical_right_shift_integer_radix_kb_assign_async(
stream,
&mut ct.as_mut().d_blocks.0.d_vec,
u32::cast_from(shift),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
);
}
}
}
}
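
The logical variant added here is what the signed division uses for its SRL step: on a signed radix ciphertext the existing right shift is arithmetic (it replicates the sign into the vacated positions), whereas SRL must fill with zeros. On clear values the distinction is the usual one (illustrative sketch):

```rust
fn main() {
    let n: i64 = -8;
    let arithmetic = n >> 2;                // SRA: sign-extends, -8 >> 2 == -2
    let logical = ((n as u64) >> 2) as i64; // SRL: fills with zeros
    assert_eq!(arithmetic, -2);
    assert_eq!(logical, 0x3FFF_FFFF_FFFF_FFFE);
    println!("SRA = {arithmetic}, SRL = {logical:#x}");
}
```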

View File

@@ -9,6 +9,7 @@ pub(crate) mod test_rotate;
pub(crate) mod test_scalar_add;
pub(crate) mod test_scalar_bitwise_op;
pub(crate) mod test_scalar_comparison;
pub(crate) mod test_scalar_div_mod;
pub(crate) mod test_scalar_mul;
pub(crate) mod test_scalar_rotate;
pub(crate) mod test_scalar_shift;
@@ -265,6 +266,46 @@ where
}
}
/// For unchecked/default binary functions with one scalar input and two encrypted outputs
impl<'a, F>
FunctionExecutor<
(&'a SignedRadixCiphertext, i64),
(SignedRadixCiphertext, SignedRadixCiphertext),
> for GpuFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
i64,
&CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(
&mut self,
input: (&'a SignedRadixCiphertext, i64),
) -> (SignedRadixCiphertext, SignedRadixCiphertext) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let (gpu_result_1, gpu_result_2) =
(self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
(
gpu_result_1.to_signed_radix_ciphertext(&context.streams),
gpu_result_2.to_signed_radix_ciphertext(&context.streams),
)
}
}
impl<'a, F> FunctionExecutor<(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), BooleanBlock>
for GpuFunctionExecutor<F>
where

View File

@@ -0,0 +1,16 @@
use crate::integer::gpu::server_key::radix::tests_unsigned::{
create_gpu_parametrized_test, GpuFunctionExecutor,
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_div_mod::signed_unchecked_scalar_div_rem_test;
use crate::shortint::parameters::*;
create_gpu_parametrized_test!(integer_signed_unchecked_scalar_div_rem);
fn integer_signed_unchecked_scalar_div_rem<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::signed_scalar_div_rem);
signed_unchecked_scalar_div_rem_test(param, executor);
}

View File

@@ -135,6 +135,7 @@ pub trait SignedReciprocable:
type DoublePrecision: DecomposableInto<u8>
+ ScalarMultiplier
+ CastFrom<<Self::Unsigned as Reciprocable>::DoublePrecision>
+ CastInto<u64>
+ std::fmt::Debug;
fn wrapping_abs(self) -> Self;

View File

@@ -10,6 +10,7 @@ pub(crate) mod test_rotate;
pub(crate) mod test_scalar_add;
pub(crate) mod test_scalar_bitwise_op;
pub(crate) mod test_scalar_comparison;
pub(crate) mod test_scalar_div_mod;
pub(crate) mod test_scalar_mul;
pub(crate) mod test_scalar_rotate;
pub(crate) mod test_scalar_shift;
@@ -488,154 +489,8 @@ fn integer_signed_default_absolute_value(param: impl Into<PBSParameters>) {
//================================================================================
// Unchecked Scalar Tests
//================================================================================
create_parametrized_test!(integer_signed_unchecked_scalar_div_rem);
create_parametrized_test!(integer_signed_unchecked_scalar_div_rem_floor);
fn integer_signed_unchecked_scalar_div_rem(param: impl Into<PBSParameters>) {
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let mut rng = rand::thread_rng();
let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
{
let clear_0 = rng.gen::<i64>() % modulus;
let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);
let result = std::panic::catch_unwind(|| {
let _ = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, 0);
});
assert!(result.is_err(), "Division by zero did not panic");
}
// check when scalar is out of ciphertext MIN..=MAX
for d in [
rng.gen_range(i64::MIN..-modulus),
rng.gen_range(modulus..=i64::MAX),
] {
for numerator in [rng.gen_range(-modulus..=0), rng.gen_range(0..modulus)] {
let ctxt_0 = cks.encrypt_signed_radix(numerator, NB_CTXT);
let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, d);
let q: i64 = cks.decrypt_signed_radix(&q_res);
let r: i64 = cks.decrypt_signed_radix(&r_res);
assert_eq!(q, signed_div_under_modulus(numerator, d, modulus));
assert_eq!(r, signed_rem_under_modulus(numerator, d, modulus));
}
}
// The algorithm has a special case for when divisor is 1 or -1
for d in [1i64, -1i64] {
let clear_0 = rng.gen::<i64>() % modulus;
let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);
let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, d);
let q: i64 = cks.decrypt_signed_radix(&q_res);
let r: i64 = cks.decrypt_signed_radix(&r_res);
assert_eq!(q, signed_div_under_modulus(clear_0, d, modulus));
assert_eq!(r, signed_rem_under_modulus(clear_0, d, modulus));
}
// 3 / -3 takes the second branch in the if else if series
for d in [3, -3] {
{
let neg_clear_0 = rng.gen_range(-modulus..=0);
let ctxt_0 = cks.encrypt_signed_radix(neg_clear_0, NB_CTXT);
println!("{neg_clear_0} / {d}");
let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, d);
let q: i64 = cks.decrypt_signed_radix(&q_res);
let r: i64 = cks.decrypt_signed_radix(&r_res);
assert_eq!(q, signed_div_under_modulus(neg_clear_0, d, modulus));
assert_eq!(r, signed_rem_under_modulus(neg_clear_0, d, modulus));
}
{
let pos_clear_0 = rng.gen_range(0..modulus);
let ctxt_0 = cks.encrypt_signed_radix(pos_clear_0, NB_CTXT);
println!("{pos_clear_0} / {d}");
let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, d);
let q: i64 = cks.decrypt_signed_radix(&q_res);
let r: i64 = cks.decrypt_signed_radix(&r_res);
assert_eq!(q, signed_div_under_modulus(pos_clear_0, d, modulus));
assert_eq!(r, signed_rem_under_modulus(pos_clear_0, d, modulus));
}
}
// Param 1_1 cannot do this, with our NB_CTXT
if modulus >= 43 {
// For param_2_2 this will take the third branch in the if else if series
for d in [-89, 89] {
{
let neg_clear_0 = rng.gen_range(-modulus..=0);
let ctxt_0 = cks.encrypt_signed_radix(neg_clear_0, NB_CTXT);
let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, d);
let q: i64 = cks.decrypt_signed_radix(&q_res);
let r: i64 = cks.decrypt_signed_radix(&r_res);
assert_eq!(q, signed_div_under_modulus(neg_clear_0, d, modulus));
assert_eq!(r, signed_rem_under_modulus(neg_clear_0, d, modulus));
}
{
let pos_clear_0 = rng.gen_range(0..modulus);
let ctxt_0 = cks.encrypt_signed_radix(pos_clear_0, NB_CTXT);
println!("{pos_clear_0} / {d}");
let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, d);
let q: i64 = cks.decrypt_signed_radix(&q_res);
let r: i64 = cks.decrypt_signed_radix(&r_res);
assert_eq!(q, signed_div_under_modulus(pos_clear_0, d, modulus));
assert_eq!(r, signed_rem_under_modulus(pos_clear_0, d, modulus));
}
}
// For param_2_2 this will take the first branch
for (clear_0, clear_1) in [(43, 8), (43, -8), (-43, 8), (-43, -8)] {
let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);
let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, clear_1);
let q: i64 = cks.decrypt_signed_radix(&q_res);
let r: i64 = cks.decrypt_signed_radix(&r_res);
assert_eq!(q, signed_div_under_modulus(clear_0, clear_1, modulus));
assert_eq!(r, signed_rem_under_modulus(clear_0, clear_1, modulus));
}
}
for d in [-modulus, modulus - 1] {
{
let neg_clear_0 = rng.gen_range(-modulus..=0);
let ctxt_0 = cks.encrypt_signed_radix(neg_clear_0, NB_CTXT);
let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, d);
let q: i64 = cks.decrypt_signed_radix(&q_res);
let r: i64 = cks.decrypt_signed_radix(&r_res);
assert_eq!(q, signed_div_under_modulus(neg_clear_0, d, modulus));
assert_eq!(r, signed_rem_under_modulus(neg_clear_0, d, modulus));
}
{
let pos_clear_0 = rng.gen_range(0..modulus);
let ctxt_0 = cks.encrypt_signed_radix(pos_clear_0, NB_CTXT);
let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, d);
let q: i64 = cks.decrypt_signed_radix(&q_res);
let r: i64 = cks.decrypt_signed_radix(&r_res);
assert_eq!(q, signed_div_under_modulus(pos_clear_0, d, modulus));
assert_eq!(r, signed_rem_under_modulus(pos_clear_0, d, modulus));
}
}
let lhs_values = random_signed_value_under_modulus::<6>(&mut rng, modulus);
let rhs_values = random_non_zero_signed_value_under_modulus::<6>(&mut rng, modulus);
for (clear_lhs, clear_rhs) in iproduct!(lhs_values, rhs_values) {
let ctxt_0 = cks.encrypt_signed_radix(clear_lhs, NB_CTXT);
let (q_res, r_res) = sks.unchecked_signed_scalar_div_rem_parallelized(&ctxt_0, clear_rhs);
let q: i64 = cks.decrypt_signed_radix(&q_res);
let r: i64 = cks.decrypt_signed_radix(&r_res);
assert_eq!(q, signed_div_under_modulus(clear_lhs, clear_rhs, modulus));
assert_eq!(r, signed_rem_under_modulus(clear_lhs, clear_rhs, modulus));
}
}
fn integer_signed_unchecked_scalar_div_rem_floor(param: impl Into<PBSParameters>) {
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

View File

@@ -0,0 +1,181 @@
use crate::integer::ciphertext::SignedRadixCiphertext;
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
use crate::integer::server_key::radix_parallel::tests_signed::{
random_non_zero_signed_value_under_modulus, random_signed_value_under_modulus,
signed_div_under_modulus, signed_rem_under_modulus, NB_CTXT,
};
use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
use crate::integer::tests::create_parametrized_test;
use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey};
#[cfg(tarpaulin)]
use crate::shortint::parameters::coverage_parameters::*;
use crate::shortint::parameters::*;
use itertools::iproduct;
use rand::prelude::*;
use std::sync::Arc;
create_parametrized_test!(integer_signed_unchecked_scalar_div_rem);
fn integer_signed_unchecked_scalar_div_rem<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor =
CpuFunctionExecutor::new(&ServerKey::unchecked_signed_scalar_div_rem_parallelized);
signed_unchecked_scalar_div_rem_test(param, executor);
}
pub(crate) fn signed_unchecked_scalar_div_rem_test<P, T>(param: P, mut executor: T)
where
P: Into<PBSParameters>,
T: for<'a> FunctionExecutor<
(&'a SignedRadixCiphertext, i64),
(SignedRadixCiphertext, SignedRadixCiphertext),
> + std::panic::UnwindSafe,
{
let param = param.into();
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let mut rng = rand::thread_rng();
let cks = RadixClientKey::from((cks, NB_CTXT));
let sks = Arc::new(sks);
let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
executor.setup(&cks, sks.clone());
// check when scalar is out of ciphertext MIN..=MAX
for d in [
rng.gen_range(i64::MIN..-modulus),
rng.gen_range(modulus..=i64::MAX),
] {
for numerator in [rng.gen_range(-modulus..=0), rng.gen_range(0..modulus)] {
let ctxt_0 = cks.encrypt_signed(numerator);
let (q_res, r_res) = executor.execute((&ctxt_0, d));
let q: i64 = cks.decrypt_signed(&q_res);
let r: i64 = cks.decrypt_signed(&r_res);
assert_eq!(q, signed_div_under_modulus(numerator, d, modulus));
assert_eq!(r, signed_rem_under_modulus(numerator, d, modulus));
}
}
// The algorithm has a special case for when divisor is 1 or -1
for d in [1i64, -1i64] {
let clear_0 = rng.gen::<i64>() % modulus;
let ctxt_0 = cks.encrypt_signed(clear_0);
let (q_res, r_res) = executor.execute((&ctxt_0, d));
let q: i64 = cks.decrypt_signed(&q_res);
let r: i64 = cks.decrypt_signed(&r_res);
assert_eq!(q, signed_div_under_modulus(clear_0, d, modulus));
assert_eq!(r, signed_rem_under_modulus(clear_0, d, modulus));
}
// 3 / -3 takes the second branch in the if else if series
for d in [3, -3] {
{
let neg_clear_0 = rng.gen_range(-modulus..=0);
let ctxt_0 = cks.encrypt_signed(neg_clear_0);
println!("{neg_clear_0} / {d}");
let (q_res, r_res) = executor.execute((&ctxt_0, d));
let q: i64 = cks.decrypt_signed(&q_res);
let r: i64 = cks.decrypt_signed(&r_res);
assert_eq!(q, signed_div_under_modulus(neg_clear_0, d, modulus));
assert_eq!(r, signed_rem_under_modulus(neg_clear_0, d, modulus));
}
{
let pos_clear_0 = rng.gen_range(0..modulus);
let ctxt_0 = cks.encrypt_signed(pos_clear_0);
println!("{pos_clear_0} / {d}");
let (q_res, r_res) = executor.execute((&ctxt_0, d));
let q: i64 = cks.decrypt_signed(&q_res);
let r: i64 = cks.decrypt_signed(&r_res);
assert_eq!(q, signed_div_under_modulus(pos_clear_0, d, modulus));
assert_eq!(r, signed_rem_under_modulus(pos_clear_0, d, modulus));
}
}
// Param 1_1 cannot do this, with our NB_CTXT
if modulus >= 43 {
// For param_2_2 this will take the third branch in the if else if series
for d in [-89, 89] {
{
let neg_clear_0 = rng.gen_range(-modulus..=0);
let ctxt_0 = cks.encrypt_signed(neg_clear_0);
let (q_res, r_res) = executor.execute((&ctxt_0, d));
let q: i64 = cks.decrypt_signed(&q_res);
let r: i64 = cks.decrypt_signed(&r_res);
assert_eq!(q, signed_div_under_modulus(neg_clear_0, d, modulus));
assert_eq!(r, signed_rem_under_modulus(neg_clear_0, d, modulus));
}
{
let pos_clear_0 = rng.gen_range(0..modulus);
let ctxt_0 = cks.encrypt_signed(pos_clear_0);
println!("{pos_clear_0} / {d}");
let (q_res, r_res) = executor.execute((&ctxt_0, d));
let q: i64 = cks.decrypt_signed(&q_res);
let r: i64 = cks.decrypt_signed(&r_res);
assert_eq!(q, signed_div_under_modulus(pos_clear_0, d, modulus));
assert_eq!(r, signed_rem_under_modulus(pos_clear_0, d, modulus));
}
}
// For param_2_2 this will take the first branch
for (clear_0, clear_1) in [(43, 8), (43, -8), (-43, 8), (-43, -8)] {
let ctxt_0 = cks.encrypt_signed(clear_0);
let (q_res, r_res) = executor.execute((&ctxt_0, clear_1));
let q: i64 = cks.decrypt_signed(&q_res);
let r: i64 = cks.decrypt_signed(&r_res);
assert_eq!(q, signed_div_under_modulus(clear_0, clear_1, modulus));
assert_eq!(r, signed_rem_under_modulus(clear_0, clear_1, modulus));
}
}
for d in [-modulus, modulus - 1] {
{
let neg_clear_0 = rng.gen_range(-modulus..=0);
let ctxt_0 = cks.encrypt_signed(neg_clear_0);
let (q_res, r_res) = executor.execute((&ctxt_0, d));
let q: i64 = cks.decrypt_signed(&q_res);
let r: i64 = cks.decrypt_signed(&r_res);
assert_eq!(q, signed_div_under_modulus(neg_clear_0, d, modulus));
assert_eq!(r, signed_rem_under_modulus(neg_clear_0, d, modulus));
}
{
let pos_clear_0 = rng.gen_range(0..modulus);
let ctxt_0 = cks.encrypt_signed(pos_clear_0);
let (q_res, r_res) = executor.execute((&ctxt_0, d));
let q: i64 = cks.decrypt_signed(&q_res);
let r: i64 = cks.decrypt_signed(&r_res);
assert_eq!(q, signed_div_under_modulus(pos_clear_0, d, modulus));
assert_eq!(r, signed_rem_under_modulus(pos_clear_0, d, modulus));
}
}
let lhs_values = random_signed_value_under_modulus::<6>(&mut rng, modulus);
let rhs_values = random_non_zero_signed_value_under_modulus::<6>(&mut rng, modulus);
for (clear_lhs, clear_rhs) in iproduct!(lhs_values, rhs_values) {
let ctxt_0 = cks.encrypt_signed(clear_lhs);
let (q_res, r_res) = executor.execute((&ctxt_0, clear_rhs));
let q: i64 = cks.decrypt_signed(&q_res);
let r: i64 = cks.decrypt_signed(&r_res);
assert_eq!(q, signed_div_under_modulus(clear_lhs, clear_rhs, modulus));
assert_eq!(r, signed_rem_under_modulus(clear_lhs, clear_rhs, modulus));
}
// Do this test last, so we can move the executor into the closure
let result = std::panic::catch_unwind(move || {
let numerator = sks.create_trivial_radix(1, NB_CTXT);
executor.execute((&numerator, 0i64));
});
assert!(result.is_err(), "division by zero should panic");
}