feat(gpu): implement scalar eq and ne
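
This commit adds GPU implementations of the scalar EQ and NE comparisons for radix ciphertexts. On the CUDA side it introduces scalar_comparison_luts (one LUT per possible packed scalar value) in int_comparison_eq_buffer, a new host_integer_radix_scalar_equality_check_kb entry point, an is_at_least_one_comparisons_block_true reduction for NE, and it threads an explicit zero_comparison LUT through host_compare_with_zero_equality. On the Rust side it exposes scalar_eq / scalar_ne (plus unchecked and async variants) on CudaServerKey and adds the corresponding GPU tests.
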
@@ -467,10 +467,10 @@ template <typename Torus> struct int_sc_prop_memory {
// create lut objects
luts_array = new int_radix_lut<Torus>(stream, params, 2, num_radix_blocks,
allocate_gpu_memory);
luts_carry_propagation_sum = new struct int_radix_lut<Torus>(
luts_carry_propagation_sum = new int_radix_lut<Torus>(
stream, params, 1, num_radix_blocks, luts_array);
message_acc = new struct int_radix_lut<Torus>(stream, params, 1,
num_radix_blocks, luts_array);
message_acc = new int_radix_lut<Torus>(stream, params, 1, num_radix_blocks,
luts_array);

auto lut_does_block_generate_carry = luts_array->get_lut(0);
auto lut_does_block_generate_or_propagate = luts_array->get_lut(1);
@@ -935,6 +935,8 @@ template <typename Torus> struct int_comparison_eq_buffer {

int_are_all_block_true_buffer<Torus> *are_all_block_true_buffer;

int_radix_lut<Torus> *scalar_comparison_luts;

int_comparison_eq_buffer(cuda_stream_t *stream, COMPARISON_TYPE op,
int_radix_params params, uint32_t num_radix_blocks,
bool allocate_gpu_memory) {
@@ -977,6 +979,22 @@ template <typename Torus> struct int_comparison_eq_buffer {
stream, is_non_zero_lut->lut, params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
is_non_zero_lut_f);

// Scalar may have up to num_radix_blocks blocks
scalar_comparison_luts = new int_radix_lut<Torus>(
stream, params, total_modulus, num_radix_blocks, allocate_gpu_memory);

for (int i = 0; i < total_modulus; i++) {
auto lut_f = [i, operator_f](Torus x) -> Torus {
return operator_f(i, x);
};

Torus *lut = scalar_comparison_luts->lut +
i * (params.glwe_dimension + 1) * params.polynomial_size;
generate_device_accumulator<Torus>(
stream, lut, params.glwe_dimension, params.polynomial_size,
params.message_modulus, params.carry_modulus, lut_f);
}
}
}
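
The constructor above precomputes one lookup table per possible clear block value: LUT i tabulates operator_f(i, x), so a scalar block holding the value i can later select LUT i by index (this is what the lut_indexes copy in host_integer_radix_scalar_equality_check_kb does). A clear-value sketch of the same idea, with hypothetical names and no FHE types:

#include <cstdint>
#include <functional>
#include <vector>

// Precompute table[s][x] = operator_f(s, x) for every possible clear block
// value s in [0, total_modulus). At comparison time, a scalar block holding
// the value s selects table[s] -- the role played by lut_indexes on the GPU.
std::vector<std::vector<uint64_t>> build_scalar_comparison_tables(
    uint32_t total_modulus,
    const std::function<uint64_t(uint64_t, uint64_t)> &operator_f) {
  std::vector<std::vector<uint64_t>> tables(total_modulus);
  for (uint32_t s = 0; s < total_modulus; ++s) {
    tables[s].resize(total_modulus);
    for (uint64_t x = 0; x < total_modulus; ++x)
      tables[s][x] = operator_f(s, x); // e.g. (s == x) for EQ, (s != x) for NE
  }
  return tables;
}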

@@ -988,6 +1006,9 @@ template <typename Torus> struct int_comparison_eq_buffer {

are_all_block_true_buffer->release(stream);
delete are_all_block_true_buffer;

scalar_comparison_luts->release(stream);
delete scalar_comparison_luts;
}
};

@@ -1064,14 +1085,8 @@ template <typename Torus> struct int_comparison_diff_buffer {

std::function<Torus(Torus)> operator_f;

int_radix_lut<Torus> *is_zero_lut;

int_tree_sign_reduction_buffer<Torus> *tree_buffer;

// Used for scalar comparisons
cuda_stream_t *lsb_stream;
cuda_stream_t *msb_stream;

int_comparison_diff_buffer(cuda_stream_t *stream, COMPARISON_TYPE op,
int_radix_params params, uint32_t num_radix_blocks,
bool allocate_gpu_memory) {
@@ -1095,8 +1110,6 @@ template <typename Torus> struct int_comparison_diff_buffer {
};

if (allocate_gpu_memory) {
lsb_stream = cuda_create_stream(stream->gpu_index);
msb_stream = cuda_create_stream(stream->gpu_index);

Torus big_size = (params.big_lwe_dimension + 1) * sizeof(Torus);

@@ -1106,36 +1119,17 @@ template <typename Torus> struct int_comparison_diff_buffer {
tmp_packed_right =
(Torus *)cuda_malloc_async(big_size * (num_radix_blocks / 2), stream);

// LUTs
uint32_t total_modulus = params.message_modulus * params.carry_modulus;
auto is_zero_f = [total_modulus](Torus x) -> Torus {
return (x % total_modulus) == 0;
};

is_zero_lut = new int_radix_lut<Torus>(
stream, params, 1, num_radix_blocks, allocate_gpu_memory);

generate_device_accumulator<Torus>(
stream, is_zero_lut->lut, params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
is_zero_f);

tree_buffer = new int_tree_sign_reduction_buffer<Torus>(
stream, operator_f, params, num_radix_blocks, allocate_gpu_memory);
}
}

void release(cuda_stream_t *stream) {
is_zero_lut->release(stream);
delete is_zero_lut;
tree_buffer->release(stream);
delete tree_buffer;

cuda_drop_async(tmp_packed_left, stream);
cuda_drop_async(tmp_packed_right, stream);

cuda_destroy_stream(lsb_stream);
cuda_destroy_stream(msb_stream);
}
};

@@ -1148,15 +1142,24 @@ template <typename Torus> struct int_comparison_buffer {
int_radix_lut<Torus> *cleaning_lut;
std::function<Torus(Torus)> cleaning_lut_f;

int_radix_lut<Torus> *is_zero_lut;

int_comparison_eq_buffer<Torus> *eq_buffer;
int_comparison_diff_buffer<Torus> *diff_buffer;

Torus *tmp_block_comparisons;
Torus *tmp_lwe_array_out;

// Scalar EQ / NE
Torus *tmp_packed_input;

// Max Min
Torus *tmp_lwe_array_out;
int_cmux_buffer<Torus> *cmux_buffer;

// Used for scalar comparisons
cuda_stream_t *lsb_stream;
cuda_stream_t *msb_stream;

int_comparison_buffer(cuda_stream_t *stream, COMPARISON_TYPE op,
int_radix_params params, uint32_t num_radix_blocks,
bool allocate_gpu_memory) {
@@ -1166,10 +1169,17 @@ template <typename Torus> struct int_comparison_buffer {
cleaning_lut_f = [](Torus x) -> Torus { return x; };

if (allocate_gpu_memory) {
lsb_stream = cuda_create_stream(stream->gpu_index);
msb_stream = cuda_create_stream(stream->gpu_index);

tmp_lwe_array_out = (Torus *)cuda_malloc_async(
(params.big_lwe_dimension + 1) * num_radix_blocks * sizeof(Torus),
stream);

tmp_packed_input = (Torus *)cuda_malloc_async(
(params.big_lwe_dimension + 1) * num_radix_blocks * sizeof(Torus),
stream);

// Block comparisons
tmp_block_comparisons = (Torus *)cuda_malloc_async(
(params.big_lwe_dimension + 1) * num_radix_blocks * sizeof(Torus),
@@ -1184,6 +1194,19 @@ template <typename Torus> struct int_comparison_buffer {
params.polynomial_size, params.message_modulus, params.carry_modulus,
cleaning_lut_f);

uint32_t total_modulus = params.message_modulus * params.carry_modulus;
auto is_zero_f = [total_modulus](Torus x) -> Torus {
return (x % total_modulus) == 0;
};

is_zero_lut = new int_radix_lut<Torus>(
stream, params, 1, num_radix_blocks, allocate_gpu_memory);

generate_device_accumulator<Torus>(
stream, is_zero_lut->lut, params.glwe_dimension,
params.polynomial_size, params.message_modulus, params.carry_modulus,
is_zero_f);

switch (op) {
case COMPARISON_TYPE::MAX:
case COMPARISON_TYPE::MIN:
@@ -1227,8 +1250,14 @@ template <typename Torus> struct int_comparison_buffer {
break;
}
cleaning_lut->release(stream);
is_zero_lut->release(stream);
delete is_zero_lut;
cuda_drop_async(tmp_lwe_array_out, stream);
cuda_drop_async(tmp_block_comparisons, stream);
cuda_drop_async(tmp_packed_input, stream);

cuda_destroy_stream(lsb_stream);
cuda_destroy_stream(msb_stream);
}
};

@@ -70,7 +70,7 @@ void cuda_comparison_integer_radix_ciphertext_kb_64(
static_cast<uint64_t *>(ksk), lwe_ciphertext_count);
break;
default:
printf("Not implemented\n");
PANIC("Cuda error: integer operation not supported");
}
}

@@ -46,6 +46,13 @@ __host__ void accumulate_all_blocks(cuda_stream_t *stream, Torus *output,
check_cuda_error(cudaGetLastError());
}

/* This takes an array of lwe ciphertexts, where each is an encryption of
* either 0 or 1.
*
* It writes in lwe_array_out a single lwe ciphertext encrypting 1 if all input
* blocks encrypt 1, and 0 otherwise
*
*/
template <typename Torus>
__host__ void
are_all_comparisons_block_true(cuda_stream_t *stream, Torus *lwe_array_out,
@@ -122,6 +129,60 @@ are_all_comparisons_block_true(cuda_stream_t *stream, Torus *lwe_array_out,
}
}

/* This takes an array of lwe ciphertexts, where each is an encryption of
* either 0 or 1.
*
* It writes in lwe_array_out a single lwe ciphertext encrypting 1 if at least
* one input ciphertext encrypts 1, and 0 otherwise
*/
template <typename Torus>
__host__ void is_at_least_one_comparisons_block_true(
cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in,
int_comparison_buffer<Torus> *mem_ptr, void *bsk, Torus *ksk,
uint32_t num_radix_blocks) {
auto params = mem_ptr->params;
auto big_lwe_dimension = params.big_lwe_dimension;
auto message_modulus = params.message_modulus;
auto carry_modulus = params.carry_modulus;

auto buffer = mem_ptr->eq_buffer->are_all_block_true_buffer;

uint32_t total_modulus = message_modulus * carry_modulus;
uint32_t max_value = total_modulus - 1;

cuda_memcpy_async_gpu_to_gpu(
lwe_array_out, lwe_array_in,
num_radix_blocks * (big_lwe_dimension + 1) * sizeof(Torus), stream);

uint32_t remaining_blocks = num_radix_blocks;
while (remaining_blocks > 1) {
// Split in max_value chunks
uint32_t chunk_length = std::min(max_value, remaining_blocks);
int num_chunks = remaining_blocks / chunk_length;

// Since all blocks encrypt either 0 or 1, we can sum max_value of them
// as in the worst case we will be adding `max_value` ones
auto input_blocks = lwe_array_out;
auto accumulator = buffer->tmp_block_accumulated;
for (int i = 0; i < num_chunks; i++) {
accumulate_all_blocks(stream, accumulator, input_blocks,
big_lwe_dimension, chunk_length);

accumulator += (big_lwe_dimension + 1);
remaining_blocks -= (chunk_length - 1);
input_blocks += (big_lwe_dimension + 1) * chunk_length;
}
accumulator = buffer->tmp_block_accumulated;

// Selects a LUT
int_radix_lut<Torus> *lut = mem_ptr->eq_buffer->is_non_zero_lut;

// Applies the LUT
integer_radix_apply_univariate_lookup_table_kb<Torus>(
stream, lwe_array_out, accumulator, bsk, ksk, num_chunks, lut);
}
}
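
Both reductions rely on the same trick: every input block encrypts 0 or 1, so up to max_value = message_modulus * carry_modulus - 1 of them can be summed into a single block without overflowing, after which one PBS maps the sum back to a 0/1 block (is_non_zero here; the all-true variant used for EQ instead tests the sum against the number of summed blocks). A simplified clear-value model of the loop above (the device code additionally carries a short tail of leftover blocks into the next iteration rather than summing it immediately):

#include <algorithm>
#include <cstdint>
#include <vector>

// Clear-value model of the chunked reduction: repeatedly sum chunks of at
// most max_value blocks (safe because every block is 0 or 1, so the sum
// stays below total_modulus), then map each sum back to a 0/1 block with
// the is_non_zero "LUT". Terminates once a single block remains.
uint64_t is_at_least_one_true(std::vector<uint64_t> blocks,
                              uint32_t max_value) {
  while (blocks.size() > 1) {
    std::vector<uint64_t> next;
    for (size_t i = 0; i < blocks.size(); i += max_value) {
      size_t end = std::min(blocks.size(), i + size_t{max_value});
      uint64_t sum = 0;
      for (size_t j = i; j < end; ++j)
        sum += blocks[j]; // never exceeds max_value
      next.push_back(sum != 0 ? 1 : 0); // the is_non_zero LUT
    }
    blocks = std::move(next);
  }
  return blocks.front();
}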

// This takes an input slice of blocks.
//
// Each block can encrypt any value as long as it is < message_modulus.
@@ -145,7 +206,7 @@ template <typename Torus>
__host__ void host_compare_with_zero_equality(
cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in,
int_comparison_buffer<Torus> *mem_ptr, void *bsk, Torus *ksk,
int32_t num_radix_blocks) {
int32_t num_radix_blocks, int_radix_lut<Torus> *zero_comparison) {

auto params = mem_ptr->params;
auto big_lwe_dimension = params.big_lwe_dimension;
@@ -175,7 +236,6 @@ __host__ void host_compare_with_zero_equality(
num_sum_blocks = 1;
} else {
uint32_t remainder_blocks = num_radix_blocks;

auto sum_i = sum;
auto chunk = lwe_array_in;
while (remainder_blocks > 1) {
@@ -194,9 +254,8 @@ __host__ void host_compare_with_zero_equality(
}
}

auto is_equal_to_zero_lut = mem_ptr->diff_buffer->is_zero_lut;
integer_radix_apply_univariate_lookup_table_kb<Torus>(
stream, sum, sum, bsk, ksk, num_sum_blocks, is_equal_to_zero_lut);
stream, sum, sum, bsk, ksk, num_sum_blocks, zero_comparison);
are_all_comparisons_block_true(stream, lwe_array_out, sum, mem_ptr, bsk, ksk,
num_sum_blocks);
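
host_compare_with_zero_equality follows the same pattern for blocks that may hold any value below message_modulus: it sums chunks of blocks (chunk sizes are bounded so a sum cannot overflow the carry space), applies the caller-supplied zero_comparison LUT to each partial sum, and finishes with are_all_comparisons_block_true, since the input is zero exactly when every chunk sum is zero. A clear-value sketch, with chunk_size standing in for the bound used on the device:

#include <cstdint>
#include <vector>

// Clear-value model of host_compare_with_zero_equality: the input is zero
// iff every block is zero. Sum chunks of blocks, map each partial sum
// through the zero_comparison "LUT" (x -> x == 0), then AND the resulting
// bits together (the are_all_comparisons_block_true step). chunk_size is a
// hypothetical stand-in for the overflow-safe bound used on the device.
uint64_t is_all_zero(const std::vector<uint64_t> &blocks, size_t chunk_size) {
  std::vector<uint64_t> bits;
  for (size_t i = 0; i < blocks.size(); i += chunk_size) {
    uint64_t sum = 0;
    for (size_t j = i; j < i + chunk_size && j < blocks.size(); ++j)
      sum += blocks[j];
    bits.push_back(sum == 0 ? 1 : 0); // zero_comparison LUT
  }
  uint64_t all_true = 1;
  for (uint64_t b : bits)
    all_true &= b;
  return all_true;
}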

@@ -8,17 +8,14 @@ void cuda_scalar_comparison_integer_radix_ciphertext_kb_64(
int_comparison_buffer<uint64_t> *buffer =
(int_comparison_buffer<uint64_t> *)mem_ptr;
switch (buffer->op) {
// case EQ:
// case NE:
// host_integer_radix_equality_check_kb<uint64_t>(
// stream, static_cast<uint64_t *>(lwe_array_out),
// static_cast<uint64_t *>(lwe_array_1),
// static_cast<uint64_t *>(lwe_array_2), buffer, bsk,
// static_cast<uint64_t *>(ksk), glwe_dimension, polynomial_size,
// big_lwe_dimension, small_lwe_dimension, ks_level, ks_base_log,
// pbs_level, pbs_base_log, grouping_factor, lwe_ciphertext_count,
// message_modulus, carry_modulus);
// break;
case EQ:
case NE:
host_integer_radix_scalar_equality_check_kb<uint64_t>(
stream, static_cast<uint64_t *>(lwe_array_out),
static_cast<uint64_t *>(lwe_array_in),
static_cast<uint64_t *>(scalar_blocks), buffer, bsk,
static_cast<uint64_t *>(ksk), lwe_ciphertext_count, num_scalar_blocks);
break;
case GT:
case GE:
case LT:
@@ -39,6 +36,6 @@ void cuda_scalar_comparison_integer_radix_ciphertext_kb_64(
static_cast<uint64_t *>(ksk), lwe_ciphertext_count, num_scalar_blocks);
break;
default:
printf("Not implemented\n");
PANIC("Cuda error: integer operation not supported");
}
}

@@ -46,9 +46,9 @@ __host__ void host_integer_radix_scalar_difference_check_kb(
if (total_num_scalar_blocks == 0) {
// We only have to compare blocks with zero
// means scalar is zero
host_compare_with_zero_equality(stream, mem_ptr->tmp_lwe_array_out,
lwe_array_in, mem_ptr, bsk, ksk,
total_num_radix_blocks);
host_compare_with_zero_equality(
stream, mem_ptr->tmp_lwe_array_out, lwe_array_in, mem_ptr, bsk, ksk,
total_num_radix_blocks, mem_ptr->is_zero_lut);

auto scalar_last_leaf_lut_f = [sign_handler_f](Torus x) -> Torus {
x = (x == 1 ? IS_EQUAL : IS_SUPERIOR);
@@ -84,8 +84,8 @@ __host__ void host_integer_radix_scalar_difference_check_kb(
auto lwe_array_msb_out = lwe_array_lsb_out + big_lwe_size;

cuda_synchronize_stream(stream);
auto lsb_stream = diff_buffer->lsb_stream;
auto msb_stream = diff_buffer->msb_stream;
auto lsb_stream = mem_ptr->lsb_stream;
auto msb_stream = mem_ptr->msb_stream;

#pragma omp parallel sections
{
@@ -128,8 +128,8 @@ __host__ void host_integer_radix_scalar_difference_check_kb(
//////////////
// msb
host_compare_with_zero_equality(msb_stream, lwe_array_msb_out, msb,
mem_ptr, bsk, ksk,
num_msb_radix_blocks);
mem_ptr, bsk, ksk, num_msb_radix_blocks,
mem_ptr->is_zero_lut);
}
}
cuda_synchronize_stream(lsb_stream);
@@ -210,16 +210,7 @@ scalar_compare_radix_blocks_kb(cuda_stream_t *stream, Torus *lwe_array_out,
Torus *ksk, uint32_t num_radix_blocks) {

auto params = mem_ptr->params;
auto pbs_type = params.pbs_type;
auto big_lwe_dimension = params.big_lwe_dimension;
auto small_lwe_dimension = params.small_lwe_dimension;
auto ks_level = params.ks_level;
auto ks_base_log = params.ks_base_log;
auto pbs_level = params.pbs_level;
auto pbs_base_log = params.pbs_base_log;
auto glwe_dimension = params.glwe_dimension;
auto polynomial_size = params.polynomial_size;
auto grouping_factor = params.grouping_factor;
auto message_modulus = params.message_modulus;
auto carry_modulus = params.carry_modulus;

@@ -295,4 +286,115 @@ __host__ void host_integer_radix_scalar_maxmin_kb(
stream, lwe_array_out, mem_ptr->tmp_lwe_array_out, lwe_array_left,
lwe_array_right, mem_ptr->cmux_buffer, bsk, ksk, total_num_radix_blocks);
}

template <typename Torus>
__host__ void host_integer_radix_scalar_equality_check_kb(
cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in,
Torus *scalar_blocks, int_comparison_buffer<Torus> *mem_ptr, void *bsk,
Torus *ksk, uint32_t num_radix_blocks, uint32_t num_scalar_blocks) {

auto params = mem_ptr->params;
auto big_lwe_dimension = params.big_lwe_dimension;
auto message_modulus = params.message_modulus;

auto eq_buffer = mem_ptr->eq_buffer;

size_t big_lwe_size = big_lwe_dimension + 1;
size_t big_lwe_size_bytes = big_lwe_size * sizeof(Torus);

auto scalar_comparison_luts = eq_buffer->scalar_comparison_luts;

uint32_t num_halved_scalar_blocks =
(num_scalar_blocks / 2) + (num_scalar_blocks % 2);

uint32_t num_lsb_radix_blocks =
std::min(num_radix_blocks, 2 * num_halved_scalar_blocks);
uint32_t num_msb_radix_blocks = num_radix_blocks - num_lsb_radix_blocks;
uint32_t num_halved_lsb_radix_blocks =
(num_lsb_radix_blocks / 2) + (num_lsb_radix_blocks % 2);

auto lsb = lwe_array_in;
auto msb = lwe_array_in + big_lwe_size * num_lsb_radix_blocks;

auto lwe_array_lsb_out = mem_ptr->tmp_lwe_array_out;
auto lwe_array_msb_out =
lwe_array_lsb_out + big_lwe_size * num_halved_lsb_radix_blocks;

cuda_synchronize_stream(stream);

auto lsb_stream = mem_ptr->lsb_stream;
auto msb_stream = mem_ptr->msb_stream;

#pragma omp parallel sections
{
// Both sections may be executed in parallel
#pragma omp section
{
if (num_halved_scalar_blocks > 0) {
auto packed_blocks = mem_ptr->tmp_packed_input;
auto packed_scalar =
packed_blocks + big_lwe_size * num_halved_lsb_radix_blocks;

pack_blocks(lsb_stream, packed_blocks, lsb, big_lwe_dimension,
num_lsb_radix_blocks, message_modulus);
pack_blocks(lsb_stream, packed_scalar, scalar_blocks, 0,
num_scalar_blocks, message_modulus);

cuda_memcpy_async_gpu_to_gpu(
scalar_comparison_luts->lut_indexes, packed_scalar,
num_halved_scalar_blocks * sizeof(Torus), lsb_stream);

integer_radix_apply_univariate_lookup_table_kb(
lsb_stream, lwe_array_lsb_out, packed_blocks, bsk, ksk,
num_halved_lsb_radix_blocks, scalar_comparison_luts);
}
}
#pragma omp section
{
//////////////
// msb
if (num_msb_radix_blocks > 0) {
int_radix_lut<Torus> *msb_lut;
switch (mem_ptr->op) {
case COMPARISON_TYPE::EQ:
msb_lut = mem_ptr->is_zero_lut;
break;
case COMPARISON_TYPE::NE:
msb_lut = mem_ptr->eq_buffer->is_non_zero_lut;
break;
default:
PANIC("Cuda error: integer operation not supported");
}

host_compare_with_zero_equality(msb_stream, lwe_array_msb_out, msb,
mem_ptr, bsk, ksk, num_msb_radix_blocks,
msb_lut);
}
}
}

cuda_synchronize_stream(lsb_stream);
cuda_synchronize_stream(msb_stream);

switch (mem_ptr->op) {
case COMPARISON_TYPE::EQ:
are_all_comparisons_block_true(
stream, lwe_array_out, lwe_array_lsb_out, mem_ptr, bsk, ksk,
num_halved_scalar_blocks + (num_msb_radix_blocks > 0));
break;
case COMPARISON_TYPE::NE:
is_at_least_one_comparisons_block_true(
stream, lwe_array_out, lwe_array_lsb_out, mem_ptr, bsk, ksk,
num_halved_scalar_blocks + (num_msb_radix_blocks > 0));
break;
default:
PANIC("Cuda error: integer operation not supported");
}

// The result will be in the first block. Everything else is
// garbage.
if (num_radix_blocks > 1)
cuda_memset_async(lwe_array_out + big_lwe_size, 0,
big_lwe_size_bytes * (num_radix_blocks - 1), stream);
}
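
Summing up the scalar path: pairs of radix blocks are packed as lo + hi * message_modulus on both the ciphertext and the scalar side; the packed clear scalar values are copied into lut_indexes so that packed block i goes through the LUT tabulating operator_f(packed_scalar[i], x); any most-significant ciphertext blocks with no scalar counterpart are checked against zero (is_zero_lut for EQ, is_non_zero_lut for NE); and the resulting 0/1 blocks are reduced with are_all_comparisons_block_true (EQ) or is_at_least_one_comparisons_block_true (NE). A clear-value sketch of the packing and per-block comparison (hypothetical helpers, no FHE):

#include <cstdint>
#include <vector>

// Pack pairs of radix blocks: (lo, hi) -> lo + hi * message_modulus, halving
// the number of blocks to compare (mirrors pack_blocks).
std::vector<uint64_t> pack_pairs(const std::vector<uint64_t> &blocks,
                                 uint64_t message_modulus) {
  std::vector<uint64_t> packed;
  for (size_t i = 0; i < blocks.size(); i += 2) {
    uint64_t lo = blocks[i];
    uint64_t hi = (i + 1 < blocks.size()) ? blocks[i + 1] : 0;
    packed.push_back(lo + hi * message_modulus);
  }
  return packed;
}

// Per-block comparison: block i goes through "LUT packed_scalar[i]", i.e.
// x -> (x == packed_scalar[i]); this mirrors the lut_indexes copy above.
std::vector<uint64_t> blockwise_eq(const std::vector<uint64_t> &packed_ct,
                                   const std::vector<uint64_t> &packed_scalar) {
  std::vector<uint64_t> out(packed_ct.size());
  for (size_t i = 0; i < packed_ct.size(); ++i)
    out[i] = (packed_ct[i] == packed_scalar[i]) ? 1 : 0;
  return out;
}
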
#endif

@@ -357,7 +357,7 @@ impl CudaServerKey {

/// Compares two ciphertexts for inequality
///
/// Returns a ciphertext containing 1 if lhs == rhs, otherwise 0
/// Returns a ciphertext containing 1 if lhs != rhs, otherwise 0
///
/// Requires carry bits to be empty
///

@@ -120,6 +120,226 @@ impl CudaServerKey {
result
}

/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronized
pub unsafe fn unchecked_scalar_eq_async<T>(
&self,
ct: &CudaRadixCiphertext,
scalar: T,
stream: &CudaStream,
) -> CudaRadixCiphertext
where
T: DecomposableInto<u64>,
{
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::EQ, stream)
}

pub fn unchecked_scalar_eq<T>(
&self,
ct: &CudaRadixCiphertext,
scalar: T,
stream: &CudaStream,
) -> CudaRadixCiphertext
where
T: DecomposableInto<u64>,
{
let result = unsafe { self.unchecked_scalar_eq_async(ct, scalar, stream) };
stream.synchronize();
result
}

/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronized
pub unsafe fn scalar_eq_async<T>(
&self,
ct: &CudaRadixCiphertext,
scalar: T,
stream: &CudaStream,
) -> CudaRadixCiphertext
where
T: DecomposableInto<u64>,
{
let mut tmp_lhs;
let lhs = if ct.block_carries_are_empty() {
ct
} else {
tmp_lhs = ct.duplicate_async(stream);
self.full_propagate_assign_async(&mut tmp_lhs, stream);
&tmp_lhs
};

self.unchecked_scalar_eq_async(lhs, scalar, stream)
}

/// Compares a ciphertext and a clear value for equality
///
/// Returns a ciphertext containing 1 if lhs == rhs, otherwise 0
///
/// Requires carry bits to be empty
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
/// use tfhe::integer::gpu::ciphertext::CudaRadixCiphertext;
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// use tfhe::integer::{gen_keys_radix, RadixCiphertext};
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
///
/// let gpu_index = 0;
/// let device = CudaDevice::new(gpu_index);
/// let mut stream = CudaStream::new_unchecked(device);
///
/// let size = 4;
///
/// // Generate the client key and the server key:
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, size, &stream);
///
/// let msg1 = 14u64;
/// let msg2 = 97u64;
///
/// let ct1 = cks.encrypt(msg1);
///
/// // Copy to GPU
/// let mut d_ct1 = CudaRadixCiphertext::from_radix_ciphertext(&ct1, &stream);
///
/// let d_ct_res = sks.scalar_eq(&d_ct1, msg2, &stream);
///
/// // Copy the result back to CPU
/// let ct_res: RadixCiphertext = d_ct_res.to_radix_ciphertext(&stream);
///
/// // Decrypt:
/// let dec_result: u64 = cks.decrypt(&ct_res);
/// assert_eq!(dec_result, u64::from(msg1 == msg2));
/// ```
pub fn scalar_eq<T>(
&self,
ct: &CudaRadixCiphertext,
scalar: T,
stream: &CudaStream,
) -> CudaRadixCiphertext
where
T: DecomposableInto<u64>,
{
let result = unsafe { self.scalar_eq_async(ct, scalar, stream) };
stream.synchronize();
result
}

/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronized
pub unsafe fn scalar_ne_async<T>(
&self,
ct: &CudaRadixCiphertext,
scalar: T,
stream: &CudaStream,
) -> CudaRadixCiphertext
where
T: DecomposableInto<u64>,
{
let mut tmp_lhs;
let lhs = if ct.block_carries_are_empty() {
ct
} else {
tmp_lhs = ct.duplicate_async(stream);
self.full_propagate_assign_async(&mut tmp_lhs, stream);
&tmp_lhs
};

self.unchecked_scalar_ne_async(lhs, scalar, stream)
}

/// Compares a ciphertext and a clear value for inequality
///
/// Returns a ciphertext containing 1 if lhs != rhs, otherwise 0
///
/// Requires carry bits to be empty
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
/// use tfhe::integer::gpu::ciphertext::CudaRadixCiphertext;
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// use tfhe::integer::{gen_keys_radix, RadixCiphertext};
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
///
/// let gpu_index = 0;
/// let device = CudaDevice::new(gpu_index);
/// let mut stream = CudaStream::new_unchecked(device);
///
/// let size = 4;
///
/// // Generate the client key and the server key:
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, size, &stream);
///
/// let msg1 = 14u64;
/// let msg2 = 97u64;
///
/// let ct1 = cks.encrypt(msg1);
///
/// // Copy to GPU
/// let mut d_ct1 = CudaRadixCiphertext::from_radix_ciphertext(&ct1, &stream);
///
/// let d_ct_res = sks.scalar_ne(&d_ct1, msg2, &stream);
///
/// // Copy the result back to CPU
/// let ct_res: RadixCiphertext = d_ct_res.to_radix_ciphertext(&stream);
///
/// // Decrypt:
/// let dec_result: u64 = cks.decrypt(&ct_res);
/// assert_eq!(dec_result, u64::from(msg1 != msg2));
/// ```
pub fn scalar_ne<T>(
&self,
ct: &CudaRadixCiphertext,
scalar: T,
stream: &CudaStream,
) -> CudaRadixCiphertext
where
T: DecomposableInto<u64>,
{
let result = unsafe { self.scalar_ne_async(ct, scalar, stream) };
stream.synchronize();
result
}

/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronized
pub unsafe fn unchecked_scalar_ne_async<T>(
&self,
ct: &CudaRadixCiphertext,
scalar: T,
stream: &CudaStream,
) -> CudaRadixCiphertext
where
T: DecomposableInto<u64>,
{
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::NE, stream)
}

pub fn unchecked_scalar_ne<T>(
&self,
ct: &CudaRadixCiphertext,
scalar: T,
stream: &CudaStream,
) -> CudaRadixCiphertext
where
T: DecomposableInto<u64>,
{
let result = unsafe { self.unchecked_scalar_ne_async(ct, scalar, stream) };
stream.synchronize();
result
}

/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
@@ -277,6 +497,7 @@ impl CudaServerKey {
stream.synchronize();
result
}

/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
@@ -315,6 +536,7 @@ impl CudaServerKey {
stream.synchronize();
result
}

/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must

@@ -53,6 +53,8 @@ create_gpu_parametrized_test!(integer_unchecked_gt);
create_gpu_parametrized_test!(integer_unchecked_ge);
create_gpu_parametrized_test!(integer_unchecked_lt);
create_gpu_parametrized_test!(integer_unchecked_le);
create_gpu_parametrized_test!(integer_unchecked_scalar_eq);
create_gpu_parametrized_test!(integer_unchecked_scalar_ne);
create_gpu_parametrized_test!(integer_unchecked_scalar_gt);
create_gpu_parametrized_test!(integer_unchecked_scalar_ge);
create_gpu_parametrized_test!(integer_unchecked_scalar_lt);
@@ -90,6 +92,8 @@ create_gpu_parametrized_test!(integer_gt);
create_gpu_parametrized_test!(integer_ge);
create_gpu_parametrized_test!(integer_lt);
create_gpu_parametrized_test!(integer_le);
create_gpu_parametrized_test!(integer_scalar_eq);
create_gpu_parametrized_test!(integer_scalar_ne);
create_gpu_parametrized_test!(integer_scalar_gt);
create_gpu_parametrized_test!(integer_scalar_ge);
create_gpu_parametrized_test!(integer_scalar_lt);
@@ -653,7 +657,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -705,7 +708,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -727,7 +729,6 @@ where
let d_ctxt_1 = CudaRadixCiphertext::from_radix_ciphertext(&ctxt_1, &stream);
let d_ctxt_2 = CudaRadixCiphertext::from_radix_ciphertext(&ctxt_2, &stream);

// let h_ct_res = h_sks.unchecked_eq(&ctxt_1, &ctxt_2);
let d_ct_res = sks.unchecked_ne(&d_ctxt_1, &d_ctxt_2, &stream);

let ct_res = d_ct_res.to_radix_ciphertext(&stream);
@@ -757,7 +758,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -799,7 +799,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -841,7 +840,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -882,7 +880,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -916,6 +913,98 @@ where
}
}

fn integer_unchecked_scalar_eq<P>(param: P)
where
P: Into<PBSParameters> + Copy,
{
let gpu_index = 0;
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
let mut rng = rand::thread_rng();

// message_modulus^vec_length
let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64;

for _ in 0..NB_TEST {
// Define the cleartexts
let clear1 = rng.gen::<u64>() % modulus;
let clear2 = rng.gen::<u64>() % modulus;

// Encrypt the integers
let ctxt_1 = cks.encrypt_radix(clear1, NB_CTXT);

// Copy to the GPU
let d_ctxt_1 = CudaRadixCiphertext::from_radix_ciphertext(&ctxt_1, &stream);

let d_ct_res = sks.unchecked_scalar_eq(&d_ctxt_1, clear2, &stream);

let ct_res = d_ct_res.to_radix_ciphertext(&stream);
let dec_res: u64 = cks.decrypt_radix(&ct_res);
let expected: u64 = (clear1 == clear2) as u64;

// Check the correctness
assert_eq!(expected, dec_res);

let d_ct_res = sks.unchecked_scalar_eq(&d_ctxt_1, clear1, &stream);

let ct_res = d_ct_res.to_radix_ciphertext(&stream);
let dec_res: u64 = cks.decrypt_radix(&ct_res);

// Check the correctness
assert_eq!(1, dec_res);
}
}

fn integer_unchecked_scalar_ne<P>(param: P)
where
P: Into<PBSParameters> + Copy,
{
let gpu_index = 0;
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
let mut rng = rand::thread_rng();

// message_modulus^vec_length
let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64;

for _ in 0..NB_TEST {
// Define the cleartexts
let clear1 = rng.gen::<u64>() % modulus;
let clear2 = rng.gen::<u64>() % modulus;

// Encrypt the integers
let ctxt_1 = cks.encrypt_radix(clear1, NB_CTXT);

// Copy to the GPU
let d_ctxt_1 = CudaRadixCiphertext::from_radix_ciphertext(&ctxt_1, &stream);

let d_ct_res = sks.unchecked_scalar_ne(&d_ctxt_1, clear2, &stream);

let ct_res = d_ct_res.to_radix_ciphertext(&stream);
let dec_res: u64 = cks.decrypt_radix(&ct_res);
let expected: u64 = (clear1 != clear2) as u64;

// Check the correctness
assert_eq!(expected, dec_res);

let d_ct_res = sks.unchecked_scalar_ne(&d_ctxt_1, clear1, &stream);

let ct_res = d_ct_res.to_radix_ciphertext(&stream);
let dec_res: u64 = cks.decrypt_radix(&ct_res);

// Check the correctness
assert_eq!(0, dec_res);
}
}

fn integer_unchecked_scalar_gt<P>(param: P)
where
P: Into<PBSParameters> + Copy,
@@ -924,7 +1013,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -983,7 +1071,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -1043,7 +1130,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -1103,7 +1189,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -1163,7 +1248,6 @@ where
let device = CudaDevice::new(gpu_index);
let mut stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &mut stream);

//RNG
@@ -1205,7 +1289,6 @@ where
let device = CudaDevice::new(gpu_index);
let mut stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &mut stream);

//RNG
@@ -1247,7 +1330,6 @@ where
let device = CudaDevice::new(gpu_index);
let mut stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &mut stream);

//RNG
@@ -1328,7 +1410,6 @@ where
let device = CudaDevice::new(gpu_index);
let mut stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &mut stream);

//RNG
@@ -1407,7 +1488,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -1562,7 +1642,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -1614,7 +1693,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -1666,7 +1744,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -1708,7 +1785,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -1750,7 +1826,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -1791,7 +1866,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -1825,6 +1899,98 @@ where
}
}

fn integer_scalar_eq<P>(param: P)
where
P: Into<PBSParameters> + Copy,
{
let gpu_index = 0;
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
let mut rng = rand::thread_rng();

// message_modulus^vec_length
let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64;

for _ in 0..NB_TEST {
// Define the cleartexts
let clear1 = rng.gen::<u64>() % modulus;
let clear2 = rng.gen::<u64>() % modulus;

// Encrypt the integers
let ctxt_1 = cks.encrypt_radix(clear1, NB_CTXT);

// Copy to the GPU
let d_ctxt_1 = CudaRadixCiphertext::from_radix_ciphertext(&ctxt_1, &stream);

let d_ct_res = sks.scalar_eq(&d_ctxt_1, clear2, &stream);

let ct_res = d_ct_res.to_radix_ciphertext(&stream);
let dec_res: u64 = cks.decrypt_radix(&ct_res);
let expected: u64 = (clear1 == clear2) as u64;

// Check the correctness
assert_eq!(expected, dec_res);

let d_ct_res = sks.scalar_eq(&d_ctxt_1, clear1, &stream);

let ct_res = d_ct_res.to_radix_ciphertext(&stream);
let dec_res: u64 = cks.decrypt_radix(&ct_res);

// Check the correctness
assert_eq!(1, dec_res);
}
}

fn integer_scalar_ne<P>(param: P)
where
P: Into<PBSParameters> + Copy,
{
let gpu_index = 0;
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
let mut rng = rand::thread_rng();

// message_modulus^vec_length
let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64;

for _ in 0..NB_TEST {
// Define the cleartexts
let clear1 = rng.gen::<u64>() % modulus;
let clear2 = rng.gen::<u64>() % modulus;

// Encrypt the integers
let ctxt_1 = cks.encrypt_radix(clear1, NB_CTXT);

// Copy to the GPU
let d_ctxt_1 = CudaRadixCiphertext::from_radix_ciphertext(&ctxt_1, &stream);

let d_ct_res = sks.scalar_ne(&d_ctxt_1, clear2, &stream);

let ct_res = d_ct_res.to_radix_ciphertext(&stream);
let dec_res: u64 = cks.decrypt_radix(&ct_res);
let expected: u64 = (clear1 != clear2) as u64;

// Check the correctness
assert_eq!(expected, dec_res);

let d_ct_res = sks.scalar_ne(&d_ctxt_1, clear1, &stream);

let ct_res = d_ct_res.to_radix_ciphertext(&stream);
let dec_res: u64 = cks.decrypt_radix(&ct_res);

// Check the correctness
assert_eq!(0, dec_res);
}
}

fn integer_scalar_gt<P>(param: P)
where
P: Into<PBSParameters> + Copy,
@@ -1833,7 +1999,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -1873,7 +2038,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -1913,7 +2077,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -1953,7 +2116,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -2041,7 +2203,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -2084,7 +2245,6 @@ where
let device = CudaDevice::new(gpu_index);
let mut stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &mut stream);

//RNG
@@ -2126,7 +2286,6 @@ where
let device = CudaDevice::new(gpu_index);
let mut stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &mut stream);

//RNG
@@ -2168,7 +2327,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG
@@ -2208,7 +2366,6 @@ where
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

// let (_, h_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let (cks, sks) = gen_keys_gpu(param, &stream);

//RNG