feat(gpu): implement signed scalar ge, gt, le, lt, max, and min
@@ -36,7 +36,7 @@ enum COMPARISON_TYPE {
  MAX = 6,
  MIN = 7,
};
-enum IS_RELATIONSHIP { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };
+enum CMP_ORDERING { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };

extern "C" {
void scratch_cuda_full_propagation_64(
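The three ordering values are deliberately two-bit codes, so a pair of orderings can be packed into one four-bit value (msb * 4 + lsb) and decoded by a lookup table. A minimal clear-value sketch of that encoding (the helper names are illustrative, not from the repository):

enum CMP_ORDERING { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };

// Pack two orderings into one value, as pack_blocks does with factor 4.
static int pack_orderings(int msb, int lsb) { return (msb << 2) | lsb; }

// Recover both halves, as the reduction LUT bodies below do.
static int unpack_msb(int packed) { return (packed >> 2) & 3; }
static int unpack_lsb(int packed) { return packed & 3; }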
@@ -1846,6 +1846,8 @@ template <typename Torus> struct int_tree_sign_reduction_buffer {
                                bool allocate_gpu_memory) {
    this->params = params;

+    Torus big_size = (params.big_lwe_dimension + 1) * sizeof(Torus);
+
    block_selector_f = [](Torus msb, Torus lsb) -> Torus {
      if (msb == IS_EQUAL) // EQUAL
        return lsb;
@@ -1854,13 +1856,8 @@ template <typename Torus> struct int_tree_sign_reduction_buffer {
    };

    if (allocate_gpu_memory) {
-      tmp_x = (Torus *)cuda_malloc_async((params.big_lwe_dimension + 1) *
-                                             num_radix_blocks * sizeof(Torus),
-                                         stream);
-      tmp_y = (Torus *)cuda_malloc_async((params.big_lwe_dimension + 1) *
-                                             num_radix_blocks * sizeof(Torus),
-                                         stream);
-
+      tmp_x = (Torus *)cuda_malloc_async(big_size * num_radix_blocks, stream);
+      tmp_y = (Torus *)cuda_malloc_async(big_size * num_radix_blocks, stream);
      // LUTs
      tree_inner_leaf_lut = new int_radix_lut<Torus>(
          stream, params, 1, num_radix_blocks, allocate_gpu_memory);
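block_selector_f is the whole reduction step: when the more-significant comparison is a tie, the less-significant comparison decides; otherwise the more-significant one wins. A clear-value sketch of the lambda above:

enum CMP_ORDERING { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };

static int block_selector(int msb, int lsb) {
  // A tie on the high blocks defers to the low blocks; any strict
  // ordering on the high blocks overrides the low blocks.
  return (msb == IS_EQUAL) ? lsb : msb;
}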
@@ -1901,6 +1898,10 @@ template <typename Torus> struct int_comparison_diff_buffer {

  int_tree_sign_reduction_buffer<Torus> *tree_buffer;

+  Torus *tmp_signs_a;
+  Torus *tmp_signs_b;
+  int_radix_lut<Torus> *reduce_signs_lut;
+
  int_comparison_diff_buffer(cuda_stream_t *stream, COMPARISON_TYPE op,
                             int_radix_params params, uint32_t num_radix_blocks,
                             bool allocate_gpu_memory) {
@@ -1922,7 +1923,6 @@ template <typename Torus> struct int_comparison_diff_buffer {
        return 42;
      }
    };

    if (allocate_gpu_memory) {
-
      Torus big_size = (params.big_lwe_dimension + 1) * sizeof(Torus);
@@ -1935,15 +1935,26 @@ template <typename Torus> struct int_comparison_diff_buffer {
      tree_buffer = new int_tree_sign_reduction_buffer<Torus>(
          stream, operator_f, params, num_radix_blocks, allocate_gpu_memory);
+      tmp_signs_a =
+          (Torus *)cuda_malloc_async(big_size * num_radix_blocks, stream);
+      tmp_signs_b =
+          (Torus *)cuda_malloc_async(big_size * num_radix_blocks, stream);
+      // LUTs
+      reduce_signs_lut = new int_radix_lut<Torus>(
+          stream, params, 1, num_radix_blocks, allocate_gpu_memory);
    }
  }

  void release(cuda_stream_t *stream) {
    tree_buffer->release(stream);
    delete tree_buffer;
+    reduce_signs_lut->release(stream);
+    delete reduce_signs_lut;

    cuda_drop_async(tmp_packed_left, stream);
    cuda_drop_async(tmp_packed_right, stream);
+    cuda_drop_async(tmp_signs_a, stream);
+    cuda_drop_async(tmp_signs_b, stream);
  }
};
@@ -1963,6 +1974,7 @@ template <typename Torus> struct int_comparison_buffer {

  Torus *tmp_block_comparisons;
  Torus *tmp_lwe_array_out;
+  Torus *tmp_trivial_sign_block;

  // Scalar EQ / NE
  Torus *tmp_packed_input;
@@ -1975,6 +1987,7 @@ template <typename Torus> struct int_comparison_buffer {
  bool is_signed;

  // Used for scalar comparisons
+  int_radix_lut<Torus> *signed_msb_lut;
  cuda_stream_t *lsb_stream;
  cuda_stream_t *msb_stream;
@@ -1987,22 +2000,22 @@ template <typename Torus> struct int_comparison_buffer {
    identity_lut_f = [](Torus x) -> Torus { return x; };

+    auto big_lwe_size = params.big_lwe_dimension + 1;
+
    if (allocate_gpu_memory) {
      lsb_stream = cuda_create_stream(stream->gpu_index);
      msb_stream = cuda_create_stream(stream->gpu_index);

+      // +1 to have space for signed comparison
      tmp_lwe_array_out = (Torus *)cuda_malloc_async(
-          (params.big_lwe_dimension + 1) * num_radix_blocks * sizeof(Torus),
-          stream);
+          big_lwe_size * (num_radix_blocks + 1) * sizeof(Torus), stream);

      tmp_packed_input = (Torus *)cuda_malloc_async(
-          (params.big_lwe_dimension + 1) * 2 * num_radix_blocks * sizeof(Torus),
-          stream);
+          big_lwe_size * 2 * num_radix_blocks * sizeof(Torus), stream);

      // Block comparisons
      tmp_block_comparisons = (Torus *)cuda_malloc_async(
-          (params.big_lwe_dimension + 1) * num_radix_blocks * sizeof(Torus),
-          stream);
+          big_lwe_size * num_radix_blocks * sizeof(Torus), stream);

      // Cleaning LUT
      identity_lut = new int_radix_lut<Torus>(
@@ -2054,13 +2067,19 @@ template <typename Torus> struct int_comparison_buffer {
    }

    if (is_signed) {
+
+      tmp_trivial_sign_block =
+          (Torus *)cuda_malloc_async(big_lwe_size * sizeof(Torus), stream);
+
      signed_lut =
          new int_radix_lut<Torus>(stream, params, 1, 1, allocate_gpu_memory);
+      signed_msb_lut =
+          new int_radix_lut<Torus>(stream, params, 1, 1, allocate_gpu_memory);

      auto message_modulus = (int)params.message_modulus;
      uint32_t sign_bit_pos = log2(message_modulus) - 1;
-      std::function<Torus(Torus, Torus)> signed_lut_f;
-      signed_lut_f = [sign_bit_pos](Torus x, Torus y) -> Torus {
+      std::function<Torus(Torus, Torus)> signed_lut_f =
+          [sign_bit_pos](Torus x, Torus y) -> Torus {
        auto x_sign_bit = x >> sign_bit_pos;
        auto y_sign_bit = y >> sign_bit_pos;

@@ -2076,14 +2095,14 @@ template <typename Torus> struct int_comparison_buffer {
          return (Torus)(IS_INFERIOR);
        else if (x == y)
          return (Torus)(IS_EQUAL);
-        else if (x > y)
+        else
          return (Torus)(IS_SUPERIOR);
      } else {
        if (x < y)
          return (Torus)(IS_SUPERIOR);
        else if (x == y)
          return (Torus)(IS_EQUAL);
-        else if (x > y)
+        else
          return (Torus)(IS_INFERIOR);
      }
      PANIC("Cuda error: sign_lut creation failed due to wrong function.")
@@ -2126,8 +2145,11 @@ template <typename Torus> struct int_comparison_buffer {
    cuda_drop_async(tmp_packed_input, stream);

    if (is_signed) {
+      cuda_drop_async(tmp_trivial_sign_block, stream);
      signed_lut->release(stream);
      delete (signed_lut);
+      signed_msb_lut->release(stream);
+      delete (signed_msb_lut);
    }
    cuda_destroy_stream(lsb_stream);
    cuda_destroy_stream(msb_stream);
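The signed LUT above compares the two sign-carrying blocks: with equal sign bits the raw block order holds, with different sign bits it inverts. A clear-value sketch of that rule (illustrative; block values are reduced modulo message_modulus and the top bit of the block is the two's-complement sign bit):

#include <cassert>

enum CMP_ORDERING { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };

static int compare_sign_blocks(int x, int y, int sign_bit_pos) {
  int x_sign = x >> sign_bit_pos;
  int y_sign = y >> sign_bit_pos;
  if (x_sign == y_sign) {
    // Same sign: the usual order on the raw block values holds.
    return (x < y) ? IS_INFERIOR : (x == y) ? IS_EQUAL : IS_SUPERIOR;
  }
  // Different signs: the block whose sign bit is set is the negative one,
  // so the raw order is inverted.
  return (x < y) ? IS_SUPERIOR : IS_INFERIOR;
}

int main() {
  // 2-bit blocks (message_modulus = 4): sign_bit_pos = 1, so raw values
  // 2 and 3 represent -2 and -1.
  assert(compare_sign_blocks(3, 1, 1) == IS_INFERIOR); // -1 < 1
  assert(compare_sign_blocks(2, 3, 1) == IS_INFERIOR); // -2 < -1
}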
@@ -273,7 +273,7 @@ __host__ void host_compare_with_zero_equality(
      remainder_blocks -= (chunk_size - 1);

      // Update operands
-      chunk += chunk_size * big_lwe_size;
+      chunk += (chunk_size - 1) * big_lwe_size;
      sum_i += big_lwe_size;
    }
  }
@@ -587,7 +587,7 @@ __global__ void device_pack_blocks(Torus *lwe_array_out, Torus *lwe_array_in,
    packed_block[tid] = lsb_block[tid] + factor * msb_block[tid];
  }

-  if (num_radix_blocks % 2 != 0) {
+  if (num_radix_blocks % 2 == 1) {
    // We couldn't pack the last block, so we just copy it
    Torus *lsb_block =
        lwe_array_in + (num_radix_blocks - 1) * (lwe_dimension + 1);
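On clear values, the packing computes lsb + factor * msb so that two block messages share one ciphertext; with message_modulus = 4 and factor = 4, the pair (3, 2) packs to 3 + 4 * 2 = 11, and both halves are recovered with % and /. A minimal sketch (helper name is illustrative):

static int pack_pair(int lsb, int msb, int factor) { return lsb + factor * msb; }

int main() {
  int packed = pack_pair(3, 2, 4);
  // 11 == 3 + 8; lsb = 11 % 4 == 3, msb = 11 / 4 == 2
  return packed == 11 ? 0 : 1;
}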
@@ -684,4 +684,91 @@ __host__ void extract_n_bits(cuda_stream_t *stream, Torus *lwe_array_out,
      num_radix_blocks * bits_per_block, bit_extract->lut);
}

template <typename Torus>
__host__ void reduce_signs(cuda_stream_t *stream, Torus *signs_array_out,
                           Torus *signs_array_in,
                           int_comparison_buffer<Torus> *mem_ptr,
                           std::function<Torus(Torus)> sign_handler_f,
                           void *bsk, Torus *ksk, uint32_t num_sign_blocks) {

  auto diff_buffer = mem_ptr->diff_buffer;

  auto params = mem_ptr->params;
  auto big_lwe_dimension = params.big_lwe_dimension;
  auto glwe_dimension = params.glwe_dimension;
  auto polynomial_size = params.polynomial_size;
  auto message_modulus = params.message_modulus;
  auto carry_modulus = params.carry_modulus;

  std::function<Torus(Torus)> reduce_two_orderings_function =
      [diff_buffer, sign_handler_f](Torus x) -> Torus {
        int msb = (x >> 2) & 3;
        int lsb = x & 3;

        return diff_buffer->tree_buffer->block_selector_f(msb, lsb);
      };

  auto signs_a = diff_buffer->tmp_signs_a;
  auto signs_b = diff_buffer->tmp_signs_b;

  cuda_memcpy_async_gpu_to_gpu(
      signs_a, signs_array_in,
      (big_lwe_dimension + 1) * num_sign_blocks * sizeof(Torus), stream);
  if (num_sign_blocks > 2) {
    auto lut = diff_buffer->reduce_signs_lut;
    generate_device_accumulator<Torus>(
        stream, lut->lut, glwe_dimension, polynomial_size, message_modulus,
        carry_modulus, reduce_two_orderings_function);

    while (num_sign_blocks > 2) {
      pack_blocks(stream, signs_b, signs_a, big_lwe_dimension, num_sign_blocks,
                  4);
      integer_radix_apply_univariate_lookup_table_kb(
          stream, signs_a, signs_b, bsk, ksk, num_sign_blocks / 2, lut);

      auto last_block_signs_b =
          signs_b + (num_sign_blocks / 2) * (big_lwe_dimension + 1);
      auto last_block_signs_a =
          signs_a + (num_sign_blocks / 2) * (big_lwe_dimension + 1);
      if (num_sign_blocks % 2 == 1)
        cuda_memcpy_async_gpu_to_gpu(last_block_signs_a, last_block_signs_b,
                                     (big_lwe_dimension + 1) * sizeof(Torus),
                                     stream);

      num_sign_blocks = (num_sign_blocks / 2) + (num_sign_blocks % 2);
    }
  }

  if (num_sign_blocks == 2) {
    std::function<Torus(Torus)> final_lut_f =
        [reduce_two_orderings_function, sign_handler_f](Torus x) -> Torus {
          Torus final_sign = reduce_two_orderings_function(x);
          return sign_handler_f(final_sign);
        };

    auto lut = diff_buffer->reduce_signs_lut;
    generate_device_accumulator<Torus>(stream, lut->lut, glwe_dimension,
                                       polynomial_size, message_modulus,
                                       carry_modulus, final_lut_f);

    pack_blocks(stream, signs_b, signs_a, big_lwe_dimension, 2, 4);
    integer_radix_apply_univariate_lookup_table_kb(stream, signs_array_out,
                                                   signs_b, bsk, ksk, 1, lut);

  } else {

    std::function<Torus(Torus)> final_lut_f =
        [mem_ptr, sign_handler_f](Torus x) -> Torus {
          return sign_handler_f(x & 3);
        };

    auto lut = mem_ptr->diff_buffer->reduce_signs_lut;
    generate_device_accumulator<Torus>(stream, lut->lut, glwe_dimension,
                                       polynomial_size, message_modulus,
                                       carry_modulus, final_lut_f);

    integer_radix_apply_univariate_lookup_table_kb(stream, signs_array_out,
                                                   signs_a, bsk, ksk, 1, lut);
  }
}
#endif // TFHE_RS_INTERNAL_INTEGER_CUH
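Each round of reduce_signs packs adjacent sign blocks as msb * 4 + lsb and lets a LUT apply block_selector_f, halving the block count until one or two remain. A clear-value simulation of the loop (a minimal sketch outside any repository API):

#include <cstdio>
#include <vector>

enum CMP_ORDERING { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };

static int block_selector(int msb, int lsb) { return msb == IS_EQUAL ? lsb : msb; }

// Simulate the reduction on clear orderings, least significant block first.
static int reduce_signs_clear(std::vector<int> signs) {
  while (signs.size() > 1) {
    std::vector<int> next;
    size_t i = 0;
    for (; i + 1 < signs.size(); i += 2)
      next.push_back(block_selector(signs[i + 1], signs[i])); // decode msb*4+lsb
    if (i < signs.size())
      next.push_back(signs[i]); // odd leftover block is carried over
    signs = next;
  }
  return signs[0];
}

int main() {
  // Equal on the low blocks, superior on the high block => SUPERIOR overall.
  printf("%d\n", reduce_signs_clear({IS_EQUAL, IS_EQUAL, IS_SUPERIOR})); // 2
}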
@@ -5,7 +5,7 @@
#include <omp.h>

template <typename Torus>
-__host__ void host_integer_radix_scalar_difference_check_kb(
+__host__ void integer_radix_unsigned_scalar_difference_check_kb(
    cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in,
    Torus *scalar_blocks, int_comparison_buffer<Torus> *mem_ptr,
    std::function<Torus(Torus)> sign_handler_f, void *bsk, Torus *ksk,
@@ -184,6 +184,344 @@ __host__ void host_integer_radix_scalar_difference_check_kb(
  }
}

template <typename Torus>
__host__ void integer_radix_signed_scalar_difference_check_kb(
    cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in,
    Torus *scalar_blocks, int_comparison_buffer<Torus> *mem_ptr,
    std::function<Torus(Torus)> sign_handler_f, void *bsk, Torus *ksk,
    uint32_t total_num_radix_blocks, uint32_t total_num_scalar_blocks) {

  cudaSetDevice(stream->gpu_index);
  auto params = mem_ptr->params;
  auto big_lwe_dimension = params.big_lwe_dimension;
  auto glwe_dimension = params.glwe_dimension;
  auto polynomial_size = params.polynomial_size;
  auto message_modulus = params.message_modulus;
  auto carry_modulus = params.carry_modulus;

  auto diff_buffer = mem_ptr->diff_buffer;

  size_t big_lwe_size = big_lwe_dimension + 1;

  // Reducing the signs is the bottleneck of the comparison algorithms;
  // however, in the scalar case there is an improvement:
  //
  // The idea is to reduce the number of sign blocks we have to
  // reduce. We can do that by splitting the comparison problem into two
  // parts.
  //
  // - One part computes the sign block between the scalar and just enough
  //   blocks from the ciphertext to represent the scalar value
  //
  // - The other part compares the ciphertext blocks not considered for the
  //   sign computation with zero, and creates a single sign block from that.
  //
  // The smaller the scalar value is relative to the number of encrypted bits
  // in the ciphertext, the more comparisons with zero we do, and the fewer
  // sign blocks we have to reduce.
  //
  // This creates a speedup, as comparing a batch of blocks with 0 is faster.
  if (total_num_scalar_blocks == 0) {
    // We only have to compare blocks with zero: the scalar is zero
    Torus *are_all_msb_zeros = mem_ptr->tmp_lwe_array_out;
    host_compare_with_zero_equality(stream, are_all_msb_zeros, lwe_array_in,
                                    mem_ptr, bsk, ksk, total_num_radix_blocks,
                                    mem_ptr->is_zero_lut);
    Torus *sign_block =
        lwe_array_in + (total_num_radix_blocks - 1) * big_lwe_size;

    auto sign_bit_pos = (int)std::log2(message_modulus) - 1;

    auto scalar_last_leaf_with_respect_to_zero_lut_f =
        [sign_handler_f, sign_bit_pos,
         message_modulus](Torus sign_block) -> Torus {
      sign_block %= message_modulus;
      int sign_bit_is_set = (sign_block >> sign_bit_pos) == 1;
      CMP_ORDERING sign_block_ordering;
      if (sign_bit_is_set) {
        sign_block_ordering = CMP_ORDERING::IS_INFERIOR;
      } else if (sign_block != 0) {
        sign_block_ordering = CMP_ORDERING::IS_SUPERIOR;
      } else {
        sign_block_ordering = CMP_ORDERING::IS_EQUAL;
      }

      return sign_block_ordering;
    };

    auto block_selector_f = mem_ptr->diff_buffer->tree_buffer->block_selector_f;
    auto scalar_bivariate_last_leaf_lut_f =
        [scalar_last_leaf_with_respect_to_zero_lut_f, sign_handler_f,
         block_selector_f](Torus are_all_zeros, Torus sign_block) -> Torus {
      // "re-code" are_all_zeros as an ordering value
      if (are_all_zeros == 1) {
        are_all_zeros = CMP_ORDERING::IS_EQUAL;
      } else {
        are_all_zeros = CMP_ORDERING::IS_SUPERIOR;
      }

      return sign_handler_f(block_selector_f(
          scalar_last_leaf_with_respect_to_zero_lut_f(sign_block),
          are_all_zeros));
    };

    auto lut = mem_ptr->diff_buffer->tree_buffer->tree_last_leaf_scalar_lut;
    generate_device_accumulator_bivariate<Torus>(
        stream, lut->lut, glwe_dimension, polynomial_size, message_modulus,
        carry_modulus, scalar_bivariate_last_leaf_lut_f);

    integer_radix_apply_bivariate_lookup_table_kb(
        stream, lwe_array_out, are_all_msb_zeros, sign_block, bsk, ksk, 1, lut);

  } else if (total_num_scalar_blocks < total_num_radix_blocks) {
    // We have to handle both parts of the work described above,
    // and the sign bit is located in the most significant blocks

    uint32_t num_lsb_radix_blocks = total_num_scalar_blocks;
    uint32_t num_msb_radix_blocks =
        total_num_radix_blocks - num_lsb_radix_blocks;
    auto msb = lwe_array_in + num_lsb_radix_blocks * big_lwe_size;

    auto lwe_array_lsb_out = mem_ptr->tmp_lwe_array_out;
    auto lwe_array_msb_out = lwe_array_lsb_out + big_lwe_size;

    cuda_synchronize_stream(stream);
    auto lsb_stream = mem_ptr->lsb_stream;
    auto msb_stream = mem_ptr->msb_stream;

#pragma omp parallel sections
    {
      // Both sections may be executed in parallel
#pragma omp section
      {
        //////////////
        // lsb
        Torus *lhs = diff_buffer->tmp_packed_left;
        Torus *rhs = diff_buffer->tmp_packed_right;

        pack_blocks(lsb_stream, lhs, lwe_array_in, big_lwe_dimension,
                    num_lsb_radix_blocks, message_modulus);
        pack_blocks(lsb_stream, rhs, scalar_blocks, 0, total_num_scalar_blocks,
                    message_modulus);

        // From this point on we have half the number of blocks
        num_lsb_radix_blocks /= 2;
        num_lsb_radix_blocks += (total_num_scalar_blocks % 2);

        // comparisons will be assigned
        // - 0 if lhs < rhs
        // - 1 if lhs == rhs
        // - 2 if lhs > rhs

        auto comparisons = mem_ptr->tmp_block_comparisons;
        scalar_compare_radix_blocks_kb(lsb_stream, comparisons, lhs, rhs,
                                       mem_ptr, bsk, ksk, num_lsb_radix_blocks);

        // Reduces a vec containing radix blocks that encrypt a sign
        // (inferior, equal, superior) to one single radix block containing
        // the final sign
        tree_sign_reduction(lsb_stream, lwe_array_lsb_out, comparisons,
                            mem_ptr->diff_buffer->tree_buffer,
                            mem_ptr->identity_lut_f, bsk, ksk,
                            num_lsb_radix_blocks);
      }
#pragma omp section
      {
        //////////////
        // msb
        // We remove the last block (which is the sign)
        Torus *are_all_msb_zeros = lwe_array_msb_out;
        host_compare_with_zero_equality(msb_stream, are_all_msb_zeros, msb,
                                        mem_ptr, bsk, ksk, num_msb_radix_blocks,
                                        mem_ptr->is_zero_lut);

        auto sign_bit_pos = (int)log2(message_modulus) - 1;

        auto lut_f = [mem_ptr, sign_bit_pos](Torus sign_block,
                                             Torus msb_are_zeros) {
          bool sign_bit_is_set = (sign_block >> sign_bit_pos) == 1;
          CMP_ORDERING sign_block_ordering;
          if (sign_bit_is_set) {
            sign_block_ordering = CMP_ORDERING::IS_INFERIOR;
          } else if (sign_block != 0) {
            sign_block_ordering = CMP_ORDERING::IS_SUPERIOR;
          } else {
            sign_block_ordering = CMP_ORDERING::IS_EQUAL;
          }

          CMP_ORDERING msb_ordering;
          if (msb_are_zeros == 1)
            msb_ordering = CMP_ORDERING::IS_EQUAL;
          else
            msb_ordering = CMP_ORDERING::IS_SUPERIOR;

          return mem_ptr->diff_buffer->tree_buffer->block_selector_f(
              sign_block_ordering, msb_ordering);
        };

        auto signed_msb_lut = mem_ptr->signed_msb_lut;
        generate_device_accumulator_bivariate<Torus>(
            msb_stream, signed_msb_lut->lut, params.glwe_dimension,
            params.polynomial_size, params.message_modulus,
            params.carry_modulus, lut_f);

        Torus *sign_block = msb + (num_msb_radix_blocks - 1) * big_lwe_size;
        integer_radix_apply_bivariate_lookup_table_kb(
            msb_stream, lwe_array_msb_out, sign_block, are_all_msb_zeros, bsk,
            ksk, 1, signed_msb_lut);
      }
    }
    cuda_synchronize_stream(lsb_stream);
    cuda_synchronize_stream(msb_stream);

    //////////////
    // Reduce the two blocks into one final block
    reduce_signs(stream, lwe_array_out, lwe_array_lsb_out, mem_ptr,
                 sign_handler_f, bsk, ksk, 2);

  } else {
    // We only have to do the regular comparison,
    // not the part where we compare the most significant blocks with zeros:
    // total_num_radix_blocks == total_num_scalar_blocks
    uint32_t num_lsb_radix_blocks = total_num_radix_blocks;

    cuda_synchronize_stream(stream);
    auto lsb_stream = mem_ptr->lsb_stream;
    auto msb_stream = mem_ptr->msb_stream;

    auto lwe_array_ct_out = mem_ptr->tmp_lwe_array_out;
    auto lwe_array_sign_out =
        lwe_array_ct_out + (num_lsb_radix_blocks / 2) * big_lwe_size;
#pragma omp parallel sections
    {
      // Both sections may be executed in parallel
#pragma omp section
      {
        Torus *lhs = diff_buffer->tmp_packed_left;
        Torus *rhs = diff_buffer->tmp_packed_right;

        pack_blocks(lsb_stream, lhs, lwe_array_in, big_lwe_dimension,
                    num_lsb_radix_blocks - 1, message_modulus);
        pack_blocks(lsb_stream, rhs, scalar_blocks, 0, num_lsb_radix_blocks - 1,
                    message_modulus);

        // From this point on we have half the number of blocks
        num_lsb_radix_blocks /= 2;

        // comparisons will be assigned
        // - 0 if lhs < rhs
        // - 1 if lhs == rhs
        // - 2 if lhs > rhs
        scalar_compare_radix_blocks_kb(lsb_stream, lwe_array_ct_out, lhs, rhs,
                                       mem_ptr, bsk, ksk, num_lsb_radix_blocks);
      }
#pragma omp section
      {
        Torus *encrypted_sign_block =
            lwe_array_in + (total_num_radix_blocks - 1) * big_lwe_size;
        Torus *scalar_sign_block =
            scalar_blocks + (total_num_scalar_blocks - 1);

        auto trivial_sign_block = mem_ptr->tmp_trivial_sign_block;
        create_trivial_radix(msb_stream, trivial_sign_block, scalar_sign_block,
                             big_lwe_dimension, 1, 1, message_modulus,
                             carry_modulus);

        integer_radix_apply_bivariate_lookup_table_kb(
            msb_stream, lwe_array_sign_out, encrypted_sign_block,
            trivial_sign_block, bsk, ksk, 1, mem_ptr->signed_lut);
      }
    }
    cuda_synchronize_stream(lsb_stream);
    cuda_synchronize_stream(msb_stream);

    // Reduces a vec containing radix blocks that encrypt a sign
    // (inferior, equal, superior) to one single radix block containing
    // the final sign
    reduce_signs(stream, lwe_array_out, lwe_array_ct_out, mem_ptr,
                 sign_handler_f, bsk, ksk, num_lsb_radix_blocks + 1);
  }
}
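A clear-value sketch of the split strategy above (illustrative, and ignoring the sign block, which the signed path treats separately): blocks beyond the scalar's width only need a zero test, and the two partial orderings combine through block_selector_f.

#include <cstddef>
#include <vector>

enum CMP_ORDERING { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };

static int block_selector(int msb, int lsb) { return msb == IS_EQUAL ? lsb : msb; }

// lhs holds per-block clear values (least significant first); the scalar
// only covers the first k blocks.
static int compare_split(const std::vector<int> &lhs,
                         const std::vector<int> &scalar, size_t k) {
  // Part 1: any non-zero block above the scalar's width makes lhs superior.
  int high = IS_EQUAL;
  for (size_t i = k; i < lhs.size(); ++i)
    if (lhs[i] != 0) high = IS_SUPERIOR;
  // Part 2: lexicographic comparison of the low blocks against the scalar.
  int low = IS_EQUAL;
  for (size_t i = 0; i < k; ++i) {
    int o = lhs[i] < scalar[i] ? IS_INFERIOR
            : lhs[i] == scalar[i] ? IS_EQUAL : IS_SUPERIOR;
    low = block_selector(o, low); // block i outranks the blocks below it
  }
  return block_selector(high, low);
}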
template <typename Torus>
__host__ void integer_radix_signed_scalar_maxmin_kb(
    cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in,
    Torus *scalar_blocks, int_comparison_buffer<Torus> *mem_ptr, void *bsk,
    Torus *ksk, uint32_t total_num_radix_blocks,
    uint32_t total_num_scalar_blocks) {

  cudaSetDevice(stream->gpu_index);
  auto params = mem_ptr->params;
  // Calculates the difference sign between the ciphertext and the scalar
  // - 0 if lhs < rhs
  // - 1 if lhs == rhs
  // - 2 if lhs > rhs
  auto sign = mem_ptr->tmp_lwe_array_out;
  integer_radix_signed_scalar_difference_check_kb(
      stream, sign, lwe_array_in, scalar_blocks, mem_ptr,
      mem_ptr->identity_lut_f, bsk, ksk, total_num_radix_blocks,
      total_num_scalar_blocks);

  // There is no optimized CMUX for scalars, so we convert to a trivial
  // ciphertext
  auto lwe_array_left = lwe_array_in;
  auto lwe_array_right = mem_ptr->tmp_block_comparisons;

  create_trivial_radix(stream, lwe_array_right, scalar_blocks,
                       params.big_lwe_dimension, total_num_radix_blocks,
                       total_num_scalar_blocks, params.message_modulus,
                       params.carry_modulus);

  // Selector
  // CMUX for Max or Min
  host_integer_radix_cmux_kb(stream, lwe_array_out, sign, lwe_array_left,
                             lwe_array_right, mem_ptr->cmux_buffer, bsk, ksk,
                             total_num_radix_blocks);
}
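Since the difference check is run with the identity handler, `sign` encodes inferior/equal/superior directly, and the CMUX then picks one operand per block. On clear values the selection amounts to (a minimal illustrative analogue, not the repository's CMUX):

enum CMP_ORDERING { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };

static long cmux_max(int sign, long lhs, long rhs) {
  // For MAX, keep lhs when it is strictly superior; MIN inverts the test.
  return (sign == IS_SUPERIOR) ? lhs : rhs;
}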
template <typename Torus>
__host__ void host_integer_radix_scalar_difference_check_kb(
    cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in,
    Torus *scalar_blocks, int_comparison_buffer<Torus> *mem_ptr,
    std::function<Torus(Torus)> sign_handler_f, void *bsk, Torus *ksk,
    uint32_t total_num_radix_blocks, uint32_t total_num_scalar_blocks) {

  if (mem_ptr->is_signed) {
    // signed comparison
    integer_radix_signed_scalar_difference_check_kb(
        stream, lwe_array_out, lwe_array_in, scalar_blocks, mem_ptr,
        sign_handler_f, bsk, ksk, total_num_radix_blocks,
        total_num_scalar_blocks);
  } else {
    integer_radix_unsigned_scalar_difference_check_kb(
        stream, lwe_array_out, lwe_array_in, scalar_blocks, mem_ptr,
        sign_handler_f, bsk, ksk, total_num_radix_blocks,
        total_num_scalar_blocks);
  }
}
template <typename Torus>
__host__ void host_integer_radix_signed_scalar_maxmin_kb(
    cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in,
    Torus *scalar_blocks, int_comparison_buffer<Torus> *mem_ptr, void *bsk,
    Torus *ksk, uint32_t total_num_radix_blocks,
    uint32_t total_num_scalar_blocks) {

  if (mem_ptr->is_signed) {
    // signed min/max
    integer_radix_signed_scalar_maxmin_kb(
        stream, lwe_array_out, lwe_array_in, scalar_blocks, mem_ptr, bsk, ksk,
        total_num_radix_blocks, total_num_scalar_blocks);
  } else {
    integer_radix_unsigned_scalar_maxmin_kb(
        stream, lwe_array_out, lwe_array_in, scalar_blocks, mem_ptr, bsk, ksk,
        total_num_radix_blocks, total_num_scalar_blocks);
  }
}
template <typename Torus>
__host__ void
scalar_compare_radix_blocks_kb(cuda_stream_t *stream, Torus *lwe_array_out,
@@ -1921,6 +1921,8 @@ mod cuda {
        cuda_unchecked_scalar_right_shift,
        cuda_unchecked_scalar_rotate_left,
        cuda_unchecked_scalar_rotate_right,
+        cuda_unchecked_scalar_eq,
+        cuda_unchecked_scalar_ne,
        cuda_unchecked_scalar_ge,
        cuda_unchecked_scalar_gt,
        cuda_unchecked_scalar_le,
@@ -1964,6 +1966,8 @@ mod cuda {
        cuda_scalar_bitand,
        cuda_scalar_bitor,
        cuda_scalar_bitxor,
+        cuda_scalar_eq,
+        cuda_scalar_ne,
        cuda_scalar_ge,
        cuda_scalar_gt,
        cuda_scalar_le,
@@ -1795,6 +1795,54 @@ mod cuda {
        rng_func: shift_scalar
    );

+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: unchecked_scalar_eq,
+        display_name: eq,
+        rng_func: default_signed_scalar
+    );
+
+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: unchecked_scalar_ne,
+        display_name: ne,
+        rng_func: default_signed_scalar
+    );
+
+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: unchecked_scalar_gt,
+        display_name: gt,
+        rng_func: default_signed_scalar
+    );
+
+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: unchecked_scalar_ge,
+        display_name: ge,
+        rng_func: default_signed_scalar
+    );
+
+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: unchecked_scalar_lt,
+        display_name: lt,
+        rng_func: default_signed_scalar
+    );
+
+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: unchecked_scalar_le,
+        display_name: le,
+        rng_func: default_signed_scalar
+    );
+
+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: unchecked_scalar_min,
+        display_name: min,
+        rng_func: default_signed_scalar
+    );
+
+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: unchecked_scalar_max,
+        display_name: max,
+        rng_func: default_signed_scalar
+    );
+
    //===========================================
    // Default
    //===========================================
@@ -1959,6 +2007,54 @@ mod cuda {
        rng_func: shift_scalar
    );

+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: scalar_eq,
+        display_name: eq,
+        rng_func: default_signed_scalar
+    );
+
+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: scalar_ne,
+        display_name: ne,
+        rng_func: default_signed_scalar
+    );
+
+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: scalar_gt,
+        display_name: gt,
+        rng_func: default_signed_scalar
+    );
+
+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: scalar_ge,
+        display_name: ge,
+        rng_func: default_signed_scalar
+    );
+
+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: scalar_lt,
+        display_name: lt,
+        rng_func: default_signed_scalar
+    );
+
+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: scalar_le,
+        display_name: le,
+        rng_func: default_signed_scalar
+    );
+
+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: scalar_min,
+        display_name: min,
+        rng_func: default_signed_scalar
+    );
+
+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: scalar_max,
+        display_name: max,
+        rng_func: default_signed_scalar
+    );
+
    criterion_group!(
        unchecked_cuda_ops,
        cuda_unchecked_add,
@@ -1995,6 +2091,14 @@ mod cuda {
        cuda_unchecked_scalar_right_shift,
        cuda_unchecked_scalar_rotate_left,
        cuda_unchecked_scalar_rotate_right,
+        cuda_unchecked_scalar_eq,
+        cuda_unchecked_scalar_ne,
+        cuda_unchecked_scalar_gt,
+        cuda_unchecked_scalar_ge,
+        cuda_unchecked_scalar_lt,
+        cuda_unchecked_scalar_le,
+        cuda_unchecked_scalar_min,
+        cuda_unchecked_scalar_max,
    );

    criterion_group!(
@@ -2034,6 +2138,14 @@ mod cuda {
        cuda_scalar_right_shift,
        cuda_scalar_rotate_left,
        cuda_scalar_rotate_right,
+        cuda_scalar_eq,
+        cuda_scalar_ne,
+        cuda_scalar_gt,
+        cuda_scalar_ge,
+        cuda_scalar_lt,
+        cuda_scalar_le,
+        cuda_scalar_min,
+        cuda_scalar_max,
    );
}
@@ -337,7 +337,7 @@ impl FheEq<bool> for FheBool {
                let inner =
                    cuda_key
                        .key
-                        .scalar_eq(&self.ciphertext.on_gpu(), u8::from(other), stream);
+                        .scalar_eq(&*self.ciphertext.on_gpu(), u8::from(other), stream);
                InnerBoolean::Cuda(inner)
            }),
        });
@@ -376,7 +376,7 @@ impl FheEq<bool> for FheBool {
                let inner =
                    cuda_key
                        .key
-                        .scalar_ne(&self.ciphertext.on_gpu(), u8::from(other), stream);
+                        .scalar_ne(&*self.ciphertext.on_gpu(), u8::from(other), stream);
                InnerBoolean::Cuda(inner)
            }),
        });
@@ -6,6 +6,8 @@ use crate::high_level_api::global_state::with_thread_local_cuda_stream;
use crate::high_level_api::keys::InternalServerKey;
#[cfg(feature = "gpu")]
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
+#[cfg(feature = "gpu")]
+use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::BooleanBlock;
use crate::prelude::{FheDecrypt, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt};
use crate::shortint::ciphertext::Degree;
@@ -104,11 +106,12 @@ impl FheTryTrivialEncrypt<bool> for FheBool {
            }
            #[cfg(feature = "gpu")]
            InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
-                let inner = cuda_key
-                    .key
-                    .create_trivial_radix(u64::from(value), 1, stream);
+                let inner: CudaUnsignedRadixCiphertext =
+                    cuda_key
+                        .key
+                        .create_trivial_radix(u64::from(value), 1, stream);
                InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext(
-                    inner.ciphertext,
+                    inner.into_inner(),
                ))
            }),
        });
@@ -5,6 +5,8 @@ use crate::high_level_api::global_state::with_thread_local_cuda_stream;
use crate::high_level_api::integers::FheUintId;
use crate::high_level_api::keys::InternalServerKey;
use crate::integer::block_decomposition::{DecomposableInto, RecomposableFrom};
+#[cfg(feature = "gpu")]
+use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use crate::prelude::{FheDecrypt, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt};
use crate::{ClientKey, CompactPublicKey, CompressedPublicKey, FheUint, PublicKey};
@@ -133,7 +135,7 @@ where
            }
            #[cfg(feature = "gpu")]
            InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
-                let inner = cuda_key.key.create_trivial_radix(
+                let inner: CudaUnsignedRadixCiphertext = cuda_key.key.create_trivial_radix(
                    value,
                    Id::num_blocks(cuda_key.key.message_modulus),
                    stream,
@@ -17,6 +17,8 @@ use crate::high_level_api::traits::{
};
use crate::integer::block_decomposition::DecomposableInto;
use crate::integer::ciphertext::IntegerCiphertext;
+#[cfg(feature = "gpu")]
+use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use crate::integer::U256;
use crate::FheBool;
use std::ops::{
@@ -60,7 +62,7 @@ where
            InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
                let inner_result = cuda_key
                    .key
-                    .scalar_eq(&self.ciphertext.on_gpu(), rhs, stream);
+                    .scalar_eq(&*self.ciphertext.on_gpu(), rhs, stream);
                FheBool::new(inner_result)
            }),
        })
@@ -97,7 +99,7 @@ where
            InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
                let inner_result = cuda_key
                    .key
-                    .scalar_ne(&self.ciphertext.on_gpu(), rhs, stream);
+                    .scalar_ne(&*self.ciphertext.on_gpu(), rhs, stream);
                FheBool::new(inner_result)
            }),
        })
@@ -140,7 +142,7 @@ where
            InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
                let inner_result = cuda_key
                    .key
-                    .scalar_lt(&self.ciphertext.on_gpu(), rhs, stream);
+                    .scalar_lt(&*self.ciphertext.on_gpu(), rhs, stream);
                FheBool::new(inner_result)
            }),
        })
@@ -177,7 +179,7 @@ where
            InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
                let inner_result = cuda_key
                    .key
-                    .scalar_le(&self.ciphertext.on_gpu(), rhs, stream);
+                    .scalar_le(&*self.ciphertext.on_gpu(), rhs, stream);
                FheBool::new(inner_result)
            }),
        })
@@ -214,7 +216,7 @@ where
            InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
                let inner_result = cuda_key
                    .key
-                    .scalar_gt(&self.ciphertext.on_gpu(), rhs, stream);
+                    .scalar_gt(&*self.ciphertext.on_gpu(), rhs, stream);
                FheBool::new(inner_result)
            }),
        })
@@ -251,7 +253,7 @@ where
            InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
                let inner_result = cuda_key
                    .key
-                    .scalar_ge(&self.ciphertext.on_gpu(), rhs, stream);
+                    .scalar_ge(&*self.ciphertext.on_gpu(), rhs, stream);
                FheBool::new(inner_result)
            }),
        })
@@ -296,7 +298,7 @@ where
            InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
                let inner_result = cuda_key
                    .key
-                    .scalar_max(&self.ciphertext.on_gpu(), rhs, stream);
+                    .scalar_max(&*self.ciphertext.on_gpu(), rhs, stream);
                Self::new(inner_result)
            }),
        })
@@ -341,7 +343,7 @@ where
            InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
                let inner_result = cuda_key
                    .key
-                    .scalar_min(&self.ciphertext.on_gpu(), rhs, stream);
+                    .scalar_min(&*self.ciphertext.on_gpu(), rhs, stream);
                Self::new(inner_result)
            }),
        })
@@ -1036,7 +1038,8 @@ generic_integer_impl_scalar_left_operation!(
            #[cfg(feature = "gpu")]
            InternalServerKey::Cuda(cuda_key) => {
                with_thread_local_cuda_stream(|stream| {
-                    let mut result = cuda_key.key.create_trivial_radix(lhs, rhs.ciphertext.on_gpu().ciphertext.info.blocks.len(), stream);
+                    let mut result: CudaUnsignedRadixCiphertext = cuda_key.key.create_trivial_radix(
+                        lhs, rhs.ciphertext.on_gpu().ciphertext.info.blocks.len(), stream);
                    cuda_key.key.sub_assign(&mut result, &rhs.ciphertext.on_gpu(), stream);
                    RadixCiphertext::Cuda(result)
                })
@@ -19,6 +19,8 @@ pub trait CudaIntegerRadixCiphertext: Sized {
        Self::from(self.as_ref().duplicate(stream))
    }

+    fn into_inner(self) -> CudaRadixCiphertext;
+
    /// # Safety
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
@@ -67,6 +69,10 @@ impl CudaIntegerRadixCiphertext for CudaUnsignedRadixCiphertext {
    fn from(ct: CudaRadixCiphertext) -> Self {
        Self { ciphertext: ct }
    }
+
+    fn into_inner(self) -> CudaRadixCiphertext {
+        self.ciphertext
+    }
}

impl CudaIntegerRadixCiphertext for CudaSignedRadixCiphertext {
@@ -83,6 +89,10 @@ impl CudaIntegerRadixCiphertext for CudaSignedRadixCiphertext {
    fn from(ct: CudaRadixCiphertext) -> Self {
        Self { ciphertext: ct }
    }
+
+    fn into_inner(self) -> CudaRadixCiphertext {
+        self.ciphertext
+    }
}

impl CudaUnsignedRadixCiphertext {
@@ -1142,7 +1142,7 @@ impl CudaStream {
        num_blocks: u32,
        num_scalar_blocks: u32,
        op: ComparisonType,
-        is_signed: bool,
+        signed_with_positive_scalar: bool,
    ) {
        let mut mem_ptr: *mut i8 = std::ptr::null_mut();
        scratch_cuda_integer_radix_comparison_kb_64(
@@ -1162,7 +1162,7 @@ impl CudaStream {
            carry_modulus.0 as u32,
            PBSType::Classical as u32,
            op as u32,
-            is_signed,
+            signed_with_positive_scalar,
            true,
        );

@@ -1209,7 +1209,7 @@ impl CudaStream {
        num_blocks: u32,
        num_scalar_blocks: u32,
        op: ComparisonType,
-        is_signed: bool,
+        signed_with_positive_scalar: bool,
    ) {
        let mut mem_ptr: *mut i8 = std::ptr::null_mut();
        scratch_cuda_integer_radix_comparison_kb_64(
@@ -1229,7 +1229,7 @@ impl CudaStream {
            carry_modulus.0 as u32,
            PBSType::MultiBit as u32,
            op as u32,
-            is_signed,
+            signed_with_positive_scalar,
            true,
        );
        cuda_scalar_comparison_integer_radix_ciphertext_kb_64(
@@ -71,7 +71,7 @@ impl CudaServerKey {
        num_blocks: usize,
        stream: &CudaStream,
    ) -> T {
-        T::from(self.create_trivial_radix(0, num_blocks, stream).ciphertext)
+        self.create_trivial_radix(0, num_blocks, stream)
    }

    /// Create a trivial ciphertext on the GPU
@@ -80,6 +80,7 @@ impl CudaServerKey {
    ///
    /// ```rust
    /// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
+    /// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
    /// use tfhe::integer::gpu::gen_keys_radix_gpu;
    /// use tfhe::integer::{gen_keys_radix, RadixCiphertext};
    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
@@ -93,20 +94,22 @@ impl CudaServerKey {
    /// // Generate the client key and the server key:
    /// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &mut stream);
    ///
-    /// let d_ctxt = sks.create_trivial_radix(212u64, num_blocks, &mut stream);
+    /// let d_ctxt: CudaUnsignedRadixCiphertext =
+    ///     sks.create_trivial_radix(212u64, num_blocks, &mut stream);
    /// let ctxt = d_ctxt.to_radix_ciphertext(&mut stream);
    ///
    /// // Decrypt:
    /// let dec: u64 = cks.decrypt(&ctxt);
    /// assert_eq!(212, dec);
    /// ```
-    pub fn create_trivial_radix<Scalar>(
+    pub fn create_trivial_radix<Scalar, T>(
        &self,
        scalar: Scalar,
        num_blocks: usize,
        stream: &CudaStream,
-    ) -> CudaUnsignedRadixCiphertext
+    ) -> T
    where
+        T: CudaIntegerRadixCiphertext,
        Scalar: DecomposableInto<u64>,
    {
        let lwe_size = match self.pbs_order {
@@ -140,12 +143,10 @@ impl CudaServerKey {

        let d_blocks = CudaLweCiphertextList::from_lwe_ciphertext_list(&cpu_lwe_list, stream);

-        CudaUnsignedRadixCiphertext {
-            ciphertext: CudaRadixCiphertext {
-                d_blocks,
-                info: CudaRadixCiphertextInfo { blocks: info },
-            },
-        }
+        T::from(CudaRadixCiphertext {
+            d_blocks,
+            info: CudaRadixCiphertextInfo { blocks: info },
+        })
    }

    /// # Safety
@@ -268,7 +269,7 @@ impl CudaServerKey {
    ///
    ///```rust
    /// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
-    /// use tfhe::integer::gpu::ciphertext::CudaRadixCiphertext;
+    /// use tfhe::integer::gpu::ciphertext::{CudaRadixCiphertext, CudaUnsignedRadixCiphertext};
    /// use tfhe::integer::gpu::gen_keys_radix_gpu;
    /// use tfhe::integer::IntegerCiphertext;
    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
@@ -282,7 +283,8 @@ impl CudaServerKey {
    /// // Generate the client key and the server key:
    /// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &mut stream);
    ///
-    /// let mut d_ct1 = sks.create_trivial_radix(7u64, num_blocks, &mut stream);
+    /// let mut d_ct1: CudaUnsignedRadixCiphertext =
+    ///     sks.create_trivial_radix(7u64, num_blocks, &mut stream);
    /// let ct1 = d_ct1.to_radix_ciphertext(&mut stream);
    /// assert_eq!(ct1.blocks().len(), 4);
    ///
@@ -342,6 +344,7 @@ impl CudaServerKey {
    ///
    ///```rust
    /// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
+    /// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
    /// use tfhe::integer::gpu::gen_keys_radix_gpu;
    /// use tfhe::integer::IntegerCiphertext;
    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
@@ -355,7 +358,8 @@ impl CudaServerKey {
    /// // Generate the client key and the server key:
    /// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &mut stream);
    ///
-    /// let mut d_ct1 = sks.create_trivial_radix(7u64, num_blocks, &mut stream);
+    /// let mut d_ct1: CudaUnsignedRadixCiphertext =
+    ///     sks.create_trivial_radix(7u64, num_blocks, &mut stream);
    /// let ct1 = d_ct1.to_radix_ciphertext(&mut stream);
    /// assert_eq!(ct1.blocks().len(), 4);
    ///
@@ -406,6 +410,7 @@ impl CudaServerKey {
    ///
    ///```rust
    /// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
+    /// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
    /// use tfhe::integer::gpu::gen_keys_radix_gpu;
    /// use tfhe::integer::IntegerCiphertext;
    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
@@ -419,7 +424,8 @@ impl CudaServerKey {
    /// // Generate the client key and the server key:
    /// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &mut stream);
    ///
-    /// let mut d_ct1 = sks.create_trivial_radix(119u64, num_blocks, &mut stream);
+    /// let mut d_ct1: CudaUnsignedRadixCiphertext =
+    ///     sks.create_trivial_radix(119u64, num_blocks, &mut stream);
    /// let ct1 = d_ct1.to_radix_ciphertext(&mut stream);
    /// assert_eq!(ct1.blocks().len(), 4);
    ///
@@ -470,6 +476,7 @@ impl CudaServerKey {
    ///
    ///```rust
    /// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
+    /// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
    /// use tfhe::integer::gpu::gen_keys_radix_gpu;
    /// use tfhe::integer::IntegerCiphertext;
    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
@@ -483,7 +490,8 @@ impl CudaServerKey {
    /// // Generate the client key and the server key:
    /// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &mut stream);
    ///
-    /// let mut d_ct1 = sks.create_trivial_radix(119u64, num_blocks, &mut stream);
+    /// let mut d_ct1: CudaUnsignedRadixCiphertext =
+    ///     sks.create_trivial_radix(119u64, num_blocks, &mut stream);
    /// let ct1 = d_ct1.to_radix_ciphertext(&mut stream);
    /// assert_eq!(ct1.blocks().len(), 4);
    ///
@@ -532,6 +540,7 @@ impl CudaServerKey {
    ///
    ///```rust
    /// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
+    /// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
    /// use tfhe::integer::gpu::gen_keys_radix_gpu;
    /// use tfhe::integer::IntegerCiphertext;
    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
@@ -546,7 +555,8 @@ impl CudaServerKey {
    ///
    /// let msg = 2u8;
    ///
-    /// let mut d_ct1 = sks.create_trivial_radix(msg, num_blocks, &mut stream);
+    /// let mut d_ct1: CudaUnsignedRadixCiphertext =
+    ///     sks.create_trivial_radix(msg, num_blocks, &mut stream);
    /// let ct1 = d_ct1.to_radix_ciphertext(&mut stream);
    /// assert_eq!(ct1.blocks().len(), 4);
    ///
@@ -5,36 +5,124 @@ use crate::core_crypto::prelude::{CiphertextModulus, LweCiphertextCount};
use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::info::CudaRadixCiphertextInfo;
-use crate::integer::gpu::ciphertext::{
-    CudaIntegerRadixCiphertext, CudaRadixCiphertext, CudaUnsignedRadixCiphertext,
-};
+use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaRadixCiphertext};
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::ComparisonType;
use crate::integer::server_key::comparator::Comparator;
use crate::shortint::ciphertext::Degree;

impl CudaServerKey {
    /// Returns whether the clear scalar is outside of the
    /// value range the ciphertext can hold.
    ///
    /// - Returns None if the scalar is in the range of values that the ciphertext can represent
    ///
    /// - Returns Some(ordering) when the scalar is out of the representable range of the
    ///   ciphertext.
    ///   - Equal will never be returned
    ///   - Less means the scalar is less than the min value representable by the ciphertext
    ///   - Greater means the scalar is greater than the max value representable by the ciphertext
    pub(crate) fn is_scalar_out_of_bounds<T, Scalar>(
        &self,
        ct: &T,
        scalar: Scalar,
    ) -> Option<std::cmp::Ordering>
    where
        T: CudaIntegerRadixCiphertext,
        Scalar: DecomposableInto<u64>,
    {
        let scalar_blocks =
            BlockDecomposer::with_early_stop_at_zero(scalar, self.message_modulus.0.ilog2())
                .iter_as::<u64>()
                .collect::<Vec<_>>();

        let ct_len = ct.as_ref().d_blocks.lwe_ciphertext_count();

        if T::IS_SIGNED {
            let sign_bit_pos = self.message_modulus.0.ilog2() - 1;
            let sign_bit_is_set = scalar_blocks
                .get(ct_len.0 - 1)
                .map_or(false, |block| (block >> sign_bit_pos) == 1);

            if scalar > Scalar::ZERO
                && (scalar_blocks.len() > ct_len.0
                    || (scalar_blocks.len() == ct_len.0 && sign_bit_is_set))
            {
                // If the scalar is positive and any bit above the ct's n-1 bits is set,
                // the scalar is bigger.
                //
                // This is checked in two steps:
                // - If there are more scalar blocks than ct blocks, the scalar is
                //   trivially bigger
                // - If there are the same number of blocks but the "sign bit" / msb of
                //   the scalar is set, then the scalar is trivially bigger
                return Some(std::cmp::Ordering::Greater);
            } else if scalar < Scalar::ZERO {
                // If the scalar is negative and any bit above the ct's n-1 bits is not
                // set, the scalar is smaller.

                if ct_len.0 > scalar_blocks.len() {
                    // Ciphertext has more blocks, the scalar may be in range
                    return None;
                }

                // (returns false for empty iter)
                let at_least_one_block_is_not_full_of_1s = scalar_blocks[ct_len.0..]
                    .iter()
                    .any(|&scalar_block| scalar_block != (self.message_modulus.0 as u64 - 1));

                let sign_bit_pos = self.message_modulus.0.ilog2() - 1;
                let sign_bit_is_unset = scalar_blocks
                    .get(ct_len.0 - 1)
                    .map_or(false, |block| (block >> sign_bit_pos) == 0);

                if at_least_one_block_is_not_full_of_1s || sign_bit_is_unset {
                    // Scalar is smaller than the lowest value of T
                    return Some(std::cmp::Ordering::Less);
                }
            }
        } else {
            // T is unsigned
            if scalar < Scalar::ZERO {
                // ct represents an unsigned (always >= 0)
                return Some(std::cmp::Ordering::Less);
            } else if scalar > Scalar::ZERO {
                // The scalar is obviously bigger if it has non-zero
                // blocks after lhs's last block
                let is_scalar_obviously_bigger =
                    scalar_blocks.get(ct_len.0..).is_some_and(|sub_slice| {
                        sub_slice.iter().any(|&scalar_block| scalar_block != 0)
                    });
                if is_scalar_obviously_bigger {
                    return Some(std::cmp::Ordering::Greater);
                }
            }
        }

        None
    }
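Conceptually, the block-wise check reduces to an interval test. A clear-value sketch (a minimal illustration in C++, assuming n_bits < 63; not the repository's block-based implementation):

#include <cstdint>
#include <optional>

enum class Ordering { Less, Greater };

// With n encrypted bits, a signed radix ciphertext represents
// [-2^(n-1), 2^(n-1) - 1] and an unsigned one [0, 2^n - 1]; a scalar outside
// that interval fixes the comparison result without any homomorphic work.
static std::optional<Ordering> is_scalar_out_of_bounds(int64_t scalar,
                                                       unsigned n_bits,
                                                       bool is_signed) {
  int64_t min = is_signed ? -(int64_t(1) << (n_bits - 1)) : 0;
  int64_t max = is_signed ? (int64_t(1) << (n_bits - 1)) - 1
                          : (int64_t(1) << n_bits) - 1;
  if (scalar < min) return Ordering::Less;
  if (scalar > max) return Ordering::Greater;
  return std::nullopt; // in range: a real comparison is needed
}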
    /// # Safety
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    ///   not be dropped until stream is synchronised
-    pub unsafe fn unchecked_scalar_comparison_async<T>(
+    pub unsafe fn unchecked_signed_and_unsigned_scalar_comparison_async<Scalar, T>(
        &self,
-        ct: &CudaUnsignedRadixCiphertext,
-        scalar: T,
+        ct: &T,
+        scalar: Scalar,
        op: ComparisonType,
+        signed_with_positive_scalar: bool,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
-        T: DecomposableInto<u64>,
+        Scalar: DecomposableInto<u64>,
+        T: CudaIntegerRadixCiphertext,
    {
-        if scalar < T::ZERO {
+        if scalar < Scalar::ZERO {
            // ct represents an unsigned (always >= 0)
-            let ct_res = self.create_trivial_radix(Comparator::IS_SUPERIOR, 1, stream);
-            return CudaBooleanBlock::from_cuda_radix_ciphertext(CudaRadixCiphertext::new(
-                ct_res.ciphertext.d_blocks,
-                ct_res.ciphertext.info,
-            ));
+            let value = match op {
+                ComparisonType::GT | ComparisonType::GE => 1,
+                _ => 0,
+            };
+            let ct_res: T = self.create_trivial_radix(value, 1, stream);
+            return CudaBooleanBlock::from_cuda_radix_ciphertext(ct_res.into_inner());
        }

        let message_modulus = self.message_modulus.0;
@@ -50,11 +138,12 @@ impl CudaServerKey {
            .get(ct.as_ref().d_blocks.lwe_ciphertext_count().0..)
            .is_some_and(|sub_slice| sub_slice.iter().any(|&scalar_block| scalar_block != 0));
        if is_scalar_obviously_bigger {
-            let ct_res = self.create_trivial_radix(Comparator::IS_INFERIOR, 1, stream);
-            return CudaBooleanBlock::from_cuda_radix_ciphertext(CudaRadixCiphertext::new(
-                ct_res.ciphertext.d_blocks,
-                ct_res.ciphertext.info,
-            ));
+            let value = match op {
+                ComparisonType::LT | ComparisonType::LE => 1,
+                _ => 0,
+            };
+            let ct_res: T = self.create_trivial_radix(value, 1, stream);
+            return CudaBooleanBlock::from_cuda_radix_ciphertext(ct_res.into_inner());
        }

        // If we are still here, that means scalar_blocks above
@@ -105,7 +194,7 @@ impl CudaServerKey {
                    lwe_ciphertext_count.0 as u32,
                    scalar_blocks.len() as u32,
                    op,
-                    false,
+                    signed_with_positive_scalar,
                );
            }
            CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
@@ -133,7 +222,7 @@ impl CudaServerKey {
                    lwe_ciphertext_count.0 as u32,
                    scalar_blocks.len() as u32,
                    op,
-                    false,
+                    signed_with_positive_scalar,
                );
            }
        }
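When the scalar is provably out of range, the boolean result depends only on the requested operation, which is what the `match op` arms above encode. A clear-value sketch of that mapping for the pure order comparisons (the enum here is redefined for illustration only):

enum class ComparisonType { EQ, NE, GT, GE, LT, LE };

// Result once we know the ciphertext is strictly greater than the scalar:
// GT and GE hold; LT and LE do not.
static int trivial_result_ct_greater(ComparisonType op) {
  return (op == ComparisonType::GT || op == ComparisonType::GE) ? 1 : 0;
}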
@@ -145,49 +234,97 @@ impl CudaServerKey {
    /// # Safety
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    ///   not be dropped until stream is synchronised
-    pub unsafe fn unchecked_scalar_minmax_async<Scalar>(
+    pub unsafe fn unchecked_scalar_comparison_async<Scalar, T>(
        &self,
-        ct: &CudaUnsignedRadixCiphertext,
+        ct: &T,
        scalar: Scalar,
        op: ComparisonType,
        stream: &CudaStream,
-    ) -> CudaUnsignedRadixCiphertext
+    ) -> CudaBooleanBlock
    where
        Scalar: DecomposableInto<u64>,
+        T: CudaIntegerRadixCiphertext,
    {
-        if scalar < Scalar::ZERO {
-            // ct represents an unsigned (always >= 0)
-            return self.create_trivial_radix(Comparator::IS_SUPERIOR, 1, stream);
-        }
+        let num_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0;
+
+        if T::IS_SIGNED {
+            match self.is_scalar_out_of_bounds(ct, scalar) {
+                Some(std::cmp::Ordering::Greater) => {
+                    // Scalar is greater than the bounds, so ciphertext is smaller
+                    let result: T = match op {
+                        ComparisonType::LT | ComparisonType::LE => {
+                            self.create_trivial_radix(1, num_blocks, stream)
+                        }
+                        _ => self.create_trivial_radix(
+                            0,
+                            ct.as_ref().d_blocks.lwe_ciphertext_count().0,
+                            stream,
+                        ),
+                    };
+                    return CudaBooleanBlock::from_cuda_radix_ciphertext(result.into_inner());
+                }
+                Some(std::cmp::Ordering::Less) => {
+                    // Scalar is smaller than the bounds, so ciphertext is bigger
+                    let result: T = match op {
+                        ComparisonType::GT | ComparisonType::GE => {
+                            self.create_trivial_radix(1, num_blocks, stream)
+                        }
+                        _ => self.create_trivial_radix(
+                            0,
+                            ct.as_ref().d_blocks.lwe_ciphertext_count().0,
+                            stream,
+                        ),
+                    };
+                    return CudaBooleanBlock::from_cuda_radix_ciphertext(result.into_inner());
+                }
+                Some(std::cmp::Ordering::Equal) => unreachable!("Internal error: invalid value"),
+                None => {
+                    // scalar is in range, fallthrough
+                }
+            }
+
+            if scalar >= Scalar::ZERO {
+                self.unchecked_signed_and_unsigned_scalar_comparison_async(
+                    ct, scalar, op, true, stream,
+                )
+            } else {
+                let scalar_as_trivial = self.create_trivial_radix(scalar, num_blocks, stream);
+                self.unchecked_comparison_async(ct, &scalar_as_trivial, op, stream)
+            }
+        } else {
+            // Unsigned
+            self.unchecked_signed_and_unsigned_scalar_comparison_async(
+                ct, scalar, op, false, stream,
+            )
+        }
+    }
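For a signed ciphertext with a negative in-range scalar, the code above falls back to encoding the scalar as a trivial ciphertext and running the ciphertext-vs-ciphertext comparison. A sketch of the two's-complement block encoding that relies on (clear values only; names are illustrative):

#include <cstdint>
#include <vector>

static std::vector<uint64_t> decompose(int64_t scalar, uint64_t base,
                                       size_t num_blocks) {
  std::vector<uint64_t> blocks(num_blocks);
  uint64_t bits = static_cast<uint64_t>(scalar); // two's-complement view
  for (size_t i = 0; i < num_blocks; ++i) {
    blocks[i] = bits % base; // least-significant block first
    bits /= base;
  }
  return blocks;
}
// decompose(-1, 4, 4) == {3, 3, 3, 3}: every block full of 1s, which is
// exactly the pattern is_scalar_out_of_bounds tests for above.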
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn unchecked_scalar_minmax_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
scalar: Scalar,
|
||||
op: ComparisonType,
|
||||
stream: &CudaStream,
|
||||
) -> T
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
Scalar: DecomposableInto<u64>,
|
||||
{
|
||||
let message_modulus = self.message_modulus.0;
|
||||
|
||||
let mut scalar_blocks =
|
||||
let scalar_blocks =
|
||||
BlockDecomposer::with_early_stop_at_zero(scalar, message_modulus.ilog2())
|
||||
.iter_as::<u64>()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// scalar is obviously bigger if it has non-zero
|
||||
// blocks after lhs's last block
|
||||
let is_scalar_obviously_bigger = scalar_blocks
|
||||
.get(ct.as_ref().d_blocks.lwe_ciphertext_count().0..)
|
||||
.is_some_and(|sub_slice| sub_slice.iter().any(|&scalar_block| scalar_block != 0));
|
||||
if is_scalar_obviously_bigger {
|
||||
return self.create_trivial_radix(Comparator::IS_INFERIOR, 1, stream);
|
||||
}
|
||||
|
||||
// If we are still here, that means scalar_blocks above
|
||||
// num_blocks are 0s, we can remove them
|
||||
// as we will handle them separately.
|
||||
scalar_blocks.truncate(ct.as_ref().d_blocks.lwe_ciphertext_count().0);
|
||||
|
||||
let d_scalar_blocks: CudaVec<u64> = CudaVec::from_cpu_async(&scalar_blocks, stream);
|
||||
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
|
||||
let mut result = CudaUnsignedRadixCiphertext {
|
||||
ciphertext: ct.as_ref().duplicate_async(stream),
|
||||
};
|
||||
let mut result = ct.duplicate_async(stream);
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -214,7 +351,7 @@ impl CudaServerKey {
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
scalar_blocks.len() as u32,
|
||||
op,
|
||||
false,
|
||||
T::IS_SIGNED,
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
@@ -242,7 +379,7 @@ impl CudaServerKey {
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
scalar_blocks.len() as u32,
|
||||
op,
|
||||
false,
|
||||
T::IS_SIGNED,
|
||||
);
|
||||
}
|
||||
}
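BlockDecomposer::with_early_stop_at_zero above splits the clear scalar into little-endian blocks of message_modulus.ilog2() bits each and stops emitting once the remaining high bits are zero; the blocks that fall past the ciphertext's block count are exactly what the is_scalar_obviously_bigger check inspects. A minimal sketch of that decomposition against plain integers (the function name is illustrative, not the library's BlockDecomposer):

/// Split `scalar` into little-endian blocks of `bits_per_block` bits,
/// stopping early once the remaining value is zero (so trailing zero
/// blocks are simply not emitted).
fn decompose_with_early_stop(mut scalar: u128, bits_per_block: u32) -> Vec<u64> {
    let mask = (1u128 << bits_per_block) - 1;
    let mut blocks = Vec::new();
    while scalar != 0 {
        blocks.push((scalar & mask) as u64);
        scalar >>= bits_per_block;
    }
    blocks
}

fn main() {
    // With a message modulus of 4, each block carries 2 bits.
    let blocks = decompose_with_early_stop(0b10_11_01, 2);
    assert_eq!(blocks, vec![0b01, 0b11, 0b10]);

    // A 4-block ciphertext compared against this 5-block scalar:
    // the block past index 3 is non-zero, so the scalar is
    // "obviously bigger" and no homomorphic comparison is needed.
    let blocks = decompose_with_early_stop(1 << 9, 2);
    assert_eq!(blocks.len(), 5);
    assert!(blocks.get(4..).is_some_and(|s| s.iter().any(|&b| b != 0)));
}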
@@ -254,26 +391,28 @@ impl CudaServerKey {
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    ///   not be dropped until the stream is synchronized
    pub unsafe fn unchecked_scalar_eq_async<Scalar>(
    pub unsafe fn unchecked_scalar_eq_async<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: CudaIntegerRadixCiphertext,
        Scalar: DecomposableInto<u64>,
    {
        self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::EQ, stream)
    }

    pub fn unchecked_scalar_eq<T>(
    pub fn unchecked_scalar_eq<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
        Scalar: DecomposableInto<u64>,
    {
        let result = unsafe { self.unchecked_scalar_eq_async(ct, scalar, stream) };
        stream.synchronize();
@@ -284,14 +423,15 @@ impl CudaServerKey {
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    ///   not be dropped until the stream is synchronized
    pub unsafe fn scalar_eq_async<T>(
    pub unsafe fn scalar_eq_async<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
        Scalar: DecomposableInto<u64>,
    {
        let mut tmp_lhs;
        let lhs = if ct.block_carries_are_empty() {
@@ -346,14 +486,15 @@ impl CudaServerKey {
    /// let dec_result = cks.decrypt_bool(&ct_res);
    /// assert_eq!(dec_result, msg1 == msg2);
    /// ```
    pub fn scalar_eq<T>(
    pub fn scalar_eq<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
        Scalar: DecomposableInto<u64>,
    {
        let result = unsafe { self.scalar_eq_async(ct, scalar, stream) };
        stream.synchronize();
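For orientation, a hedged end-to-end sketch of driving scalar_eq from the host. It reuses only calls that appear in this commit's own test code (gen_keys_gpu, encrypt_radix, from_radix_ciphertext, to_boolean_block, decrypt_bool); the import paths and the parameter-set name are assumptions rather than verified against the crate:

// A minimal sketch; the `use` paths below are assumed, mirroring the
// `crate::integer::gpu::...` paths used by the tests in this commit.
use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use tfhe::integer::gpu::{gen_keys_gpu, CudaDevice, CudaStream};
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

fn main() {
    let device = CudaDevice::new(0);
    let stream = CudaStream::new_unchecked(device);
    let (cks, sks) = gen_keys_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, &stream);

    // 2-bit message blocks: 8 blocks hold a 16-bit value.
    let ct = cks.encrypt_radix(1234u64, 8);
    let d_ct = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &stream);

    // scalar_eq cleans carries if needed, launches the comparison, and
    // synchronizes the stream before returning the boolean block.
    let d_res = sks.scalar_eq(&d_ct, 1234u64, &stream);
    let res = d_res.to_boolean_block(&stream);
    assert!(cks.decrypt_bool(&res));
}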
@@ -364,14 +505,15 @@ impl CudaServerKey {
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    ///   not be dropped until the stream is synchronized
    pub unsafe fn scalar_ne_async<T>(
    pub unsafe fn scalar_ne_async<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
        Scalar: DecomposableInto<u64>,
    {
        let mut tmp_lhs;
        let lhs = if ct.block_carries_are_empty() {
@@ -426,14 +568,15 @@ impl CudaServerKey {
    /// let dec_result = cks.decrypt_bool(&ct_res);
    /// assert_eq!(dec_result, msg1 != msg2);
    /// ```
    pub fn scalar_ne<T>(
    pub fn scalar_ne<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        let result = unsafe { self.scalar_ne_async(ct, scalar, stream) };
        stream.synchronize();
@@ -444,26 +587,28 @@ impl CudaServerKey {
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    ///   not be dropped until the stream is synchronized
    pub unsafe fn unchecked_scalar_ne_async<T>(
    pub unsafe fn unchecked_scalar_ne_async<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
        Scalar: DecomposableInto<u64>,
    {
        self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::NE, stream)
    }

    pub fn unchecked_scalar_ne<T>(
    pub fn unchecked_scalar_ne<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
        Scalar: DecomposableInto<u64>,
    {
        let result = unsafe { self.unchecked_scalar_ne_async(ct, scalar, stream) };
        stream.synchronize();
@@ -474,26 +619,28 @@ impl CudaServerKey {
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    ///   not be dropped until the stream is synchronized
    pub unsafe fn unchecked_scalar_gt_async<T>(
    pub unsafe fn unchecked_scalar_gt_async<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::GT, stream)
    }

    pub fn unchecked_scalar_gt<T>(
    pub fn unchecked_scalar_gt<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        let result = unsafe { self.unchecked_scalar_gt_async(ct, scalar, stream) };
        stream.synchronize();
@@ -504,26 +651,28 @@ impl CudaServerKey {
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    ///   not be dropped until the stream is synchronized
    pub unsafe fn unchecked_scalar_ge_async<T>(
    pub unsafe fn unchecked_scalar_ge_async<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::GE, stream)
    }

    pub fn unchecked_scalar_ge<T>(
    pub fn unchecked_scalar_ge<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        let result = unsafe { self.unchecked_scalar_ge_async(ct, scalar, stream) };
        stream.synchronize();
@@ -534,26 +683,28 @@ impl CudaServerKey {
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    ///   not be dropped until the stream is synchronized
    pub unsafe fn unchecked_scalar_lt_async<T>(
    pub unsafe fn unchecked_scalar_lt_async<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::LT, stream)
    }

    pub fn unchecked_scalar_lt<T>(
    pub fn unchecked_scalar_lt<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        let result = unsafe { self.unchecked_scalar_lt_async(ct, scalar, stream) };
        stream.synchronize();
@@ -564,26 +715,28 @@ impl CudaServerKey {
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    ///   not be dropped until the stream is synchronized
    pub unsafe fn unchecked_scalar_le_async<T>(
    pub unsafe fn unchecked_scalar_le_async<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::LE, stream)
    }

    pub fn unchecked_scalar_le<T>(
    pub fn unchecked_scalar_le<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        let result = unsafe { self.unchecked_scalar_le_async(ct, scalar, stream) };
        stream.synchronize();
@@ -593,14 +746,15 @@ impl CudaServerKey {
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    ///   not be dropped until the stream is synchronized
    pub unsafe fn scalar_gt_async<T>(
    pub unsafe fn scalar_gt_async<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        let mut tmp_lhs;
        let lhs = if ct.block_carries_are_empty() {
@@ -614,14 +768,15 @@ impl CudaServerKey {
        self.unchecked_scalar_gt_async(lhs, scalar, stream)
    }

    pub fn scalar_gt<T>(
    pub fn scalar_gt<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        let result = unsafe { self.scalar_gt_async(ct, scalar, stream) };
        stream.synchronize();
@@ -632,14 +787,15 @@ impl CudaServerKey {
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    ///   not be dropped until the stream is synchronized
    pub unsafe fn scalar_ge_async<T>(
    pub unsafe fn scalar_ge_async<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        let mut tmp_lhs;
        let lhs = if ct.block_carries_are_empty() {
@@ -653,14 +809,15 @@ impl CudaServerKey {
        self.unchecked_scalar_ge_async(lhs, scalar, stream)
    }

    pub fn scalar_ge<T>(
    pub fn scalar_ge<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        let result = unsafe { self.scalar_ge_async(ct, scalar, stream) };
        stream.synchronize();
@@ -671,14 +828,15 @@ impl CudaServerKey {
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    ///   not be dropped until the stream is synchronized
    pub unsafe fn scalar_lt_async<T>(
    pub unsafe fn scalar_lt_async<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        let mut tmp_lhs;
        let lhs = if ct.block_carries_are_empty() {
@@ -692,14 +850,15 @@ impl CudaServerKey {
        self.unchecked_scalar_lt_async(lhs, scalar, stream)
    }

    pub fn scalar_lt<T>(
    pub fn scalar_lt<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        let result = unsafe { self.scalar_lt_async(ct, scalar, stream) };
        stream.synchronize();
@@ -709,14 +868,15 @@ impl CudaServerKey {
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    ///   not be dropped until the stream is synchronized
    pub unsafe fn scalar_le_async<T>(
    pub unsafe fn scalar_le_async<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        let mut tmp_lhs;
        let lhs = if ct.block_carries_are_empty() {
@@ -730,14 +890,15 @@ impl CudaServerKey {
        self.unchecked_scalar_le_async(lhs, scalar, stream)
    }

    pub fn scalar_le<T>(
    pub fn scalar_le<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaBooleanBlock
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        let result = unsafe { self.scalar_le_async(ct, scalar, stream) };
        stream.synchronize();
@@ -748,26 +909,23 @@ impl CudaServerKey {
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    ///   not be dropped until the stream is synchronized
    pub unsafe fn unchecked_scalar_max_async<T>(
    pub unsafe fn unchecked_scalar_max_async<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaUnsignedRadixCiphertext
    ) -> T
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        self.unchecked_scalar_minmax_async(ct, scalar, ComparisonType::MAX, stream)
    }

    pub fn unchecked_scalar_max<T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        stream: &CudaStream,
    ) -> CudaUnsignedRadixCiphertext
    pub fn unchecked_scalar_max<Scalar, T>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        let result = unsafe { self.unchecked_scalar_max_async(ct, scalar, stream) };
        stream.synchronize();
@@ -778,26 +936,23 @@ impl CudaServerKey {
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    ///   not be dropped until the stream is synchronized
    pub unsafe fn unchecked_scalar_min_async<Scalar>(
    pub unsafe fn unchecked_scalar_min_async<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaUnsignedRadixCiphertext
    ) -> T
    where
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        self.unchecked_scalar_minmax_async(ct, scalar, ComparisonType::MIN, stream)
    }

    pub fn unchecked_scalar_min<Scalar>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaUnsignedRadixCiphertext
    pub fn unchecked_scalar_min<Scalar, T>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
    where
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        let result = unsafe { self.unchecked_scalar_min_async(ct, scalar, stream) };
        stream.synchronize();
@@ -808,14 +963,15 @@ impl CudaServerKey {
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    ///   not be dropped until the stream is synchronized
    pub unsafe fn scalar_max_async<T>(
    pub unsafe fn scalar_max_async<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaUnsignedRadixCiphertext
    ) -> T
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        let mut tmp_lhs;
        let lhs = if ct.block_carries_are_empty() {
@@ -829,14 +985,10 @@ impl CudaServerKey {
        self.unchecked_scalar_max_async(lhs, scalar, stream)
    }

    pub fn scalar_max<T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        stream: &CudaStream,
    ) -> CudaUnsignedRadixCiphertext
    pub fn scalar_max<Scalar, T>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        let result = unsafe { self.scalar_max_async(ct, scalar, stream) };
        stream.synchronize();
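Unlike the comparisons, the min/max entry points return a radix ciphertext (T) rather than a CudaBooleanBlock. A hedged follow-up to the scalar_eq sketch above, with the same assumed imports and setup; to_radix_ciphertext and decrypt_radix are assumed counterparts of the conversion and decryption calls shown elsewhere in this commit, not verified signatures:

// A minimal sketch; `to_radix_ciphertext` and `decrypt_radix` are assumed.
fn main() {
    let device = CudaDevice::new(0);
    let stream = CudaStream::new_unchecked(device);
    let (cks, sks) = gen_keys_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, &stream);

    let ct = cks.encrypt_radix(200u64, 8);
    let d_ct = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &stream);

    // MAX keeps the larger of the encrypted value and the clear scalar,
    // which unchecked_scalar_minmax_async block-decomposes on the fly.
    let d_max = sks.scalar_max(&d_ct, 1000u64, &stream);
    let max = d_max.to_radix_ciphertext(&stream);
    let clear: u64 = cks.decrypt_radix(&max);
    assert_eq!(clear, 1000);
}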
@@ -847,14 +999,15 @@ impl CudaServerKey {
    ///
    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    ///   not be dropped until the stream is synchronized
    pub unsafe fn scalar_min_async<T>(
    pub unsafe fn scalar_min_async<Scalar, T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        ct: &T,
        scalar: Scalar,
        stream: &CudaStream,
    ) -> CudaUnsignedRadixCiphertext
    ) -> T
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        let mut tmp_lhs;
        let lhs = if ct.block_carries_are_empty() {
@@ -868,14 +1021,10 @@ impl CudaServerKey {
        self.unchecked_scalar_min_async(lhs, scalar, stream)
    }

    pub fn scalar_min<T>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        scalar: T,
        stream: &CudaStream,
    ) -> CudaUnsignedRadixCiphertext
    pub fn scalar_min<Scalar, T>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
    where
        T: DecomposableInto<u64>,
        Scalar: DecomposableInto<u64>,
        T: CudaIntegerRadixCiphertext,
    {
        let result = unsafe { self.scalar_min_async(ct, scalar, stream) };
        stream.synchronize();

@@ -7,6 +7,7 @@ pub(crate) mod test_neg;
pub(crate) mod test_rotate;
pub(crate) mod test_scalar_add;
pub(crate) mod test_scalar_bitwise_op;
pub(crate) mod test_scalar_comparison;
pub(crate) mod test_scalar_mul;
pub(crate) mod test_scalar_rotate;
pub(crate) mod test_scalar_shift;

@@ -0,0 +1,92 @@
use crate::integer::gpu::server_key::radix::tests_unsigned::{
    create_gpu_parametrized_test, GpuFunctionExecutor,
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_comparison::{
    test_signed_default_scalar_function, test_signed_default_scalar_minmax,
    test_signed_unchecked_scalar_function, test_signed_unchecked_scalar_minmax,
};
use crate::shortint::parameters::*;

/// This macro generates the tests for a given comparison fn
///
/// All our comparison functions have 2 variants:
/// - unchecked_$comparison_name
/// - $comparison_name
///
/// So, for example, for the `gt` comparison fn, this macro will generate the tests for
/// the 2 variants described above
macro_rules! define_gpu_signed_scalar_comparison_test_functions {
    ($comparison_name:ident, $clear_type:ty) => {
        ::paste::paste!{
            fn [<integer_signed_unchecked_scalar_ $comparison_name _ $clear_type>]<P>(param: P) where P: Into<PBSParameters> {
                let num_tests = 1;
                let executor = GpuFunctionExecutor::new(&CudaServerKey::[<unchecked_scalar_ $comparison_name>]);
                test_signed_unchecked_scalar_function(
                    param,
                    num_tests,
                    executor,
                    |lhs, rhs| $clear_type::from(<$clear_type>::$comparison_name(&lhs, &rhs)),
                )
            }

            fn [<integer_signed_default_scalar_ $comparison_name _ $clear_type>]<P>(param: P) where P: Into<PBSParameters> {
                let num_tests = 10;
                let executor = GpuFunctionExecutor::new(&CudaServerKey::[<scalar_ $comparison_name>]);
                test_signed_default_scalar_function(
                    param,
                    num_tests,
                    executor,
                    |lhs, rhs| $clear_type::from(<$clear_type>::$comparison_name(&lhs, &rhs)),
                )
            }

            create_gpu_parametrized_test!([<integer_signed_unchecked_scalar_ $comparison_name _ $clear_type>]);
            create_gpu_parametrized_test!([<integer_signed_default_scalar_ $comparison_name _ $clear_type>]);
        }
    };
}
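For reference, a hand-expanded sketch of what one invocation of the macro produces; ::paste::paste! splices the fragments inside [< ... >] into a single identifier, so this is written out manually rather than taken from compiler output:

// Hand-expanded sketch of: define_gpu_signed_scalar_comparison_test_functions!(eq, i128);
fn integer_signed_unchecked_scalar_eq_i128<P>(param: P)
where
    P: Into<PBSParameters>,
{
    let num_tests = 1;
    let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_eq);
    test_signed_unchecked_scalar_function(param, num_tests, executor, |lhs, rhs| {
        // i128 implements From<bool>, turning the clear comparison into 0 or 1
        i128::from(<i128>::eq(&lhs, &rhs))
    })
}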

fn integer_signed_unchecked_scalar_min_i128<P>(params: P)
where
    P: Into<PBSParameters>,
{
    let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_min);
    test_signed_unchecked_scalar_minmax(params, 2, executor, std::cmp::min::<i128>);
}

fn integer_signed_unchecked_scalar_max_i128<P>(params: P)
where
    P: Into<PBSParameters>,
{
    let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_max);
    test_signed_unchecked_scalar_minmax(params, 2, executor, std::cmp::max::<i128>);
}

fn integer_signed_scalar_min_i128<P>(params: P)
where
    P: Into<PBSParameters>,
{
    let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_min);
    test_signed_default_scalar_minmax(params, 2, executor, std::cmp::min::<i128>);
}

fn integer_signed_scalar_max_i128<P>(params: P)
where
    P: Into<PBSParameters>,
{
    let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_max);
    test_signed_default_scalar_minmax(params, 2, executor, std::cmp::max::<i128>);
}

create_gpu_parametrized_test!(integer_signed_unchecked_scalar_max_i128);
create_gpu_parametrized_test!(integer_signed_unchecked_scalar_min_i128);
create_gpu_parametrized_test!(integer_signed_scalar_max_i128);
create_gpu_parametrized_test!(integer_signed_scalar_min_i128);

define_gpu_signed_scalar_comparison_test_functions!(eq, i128);
define_gpu_signed_scalar_comparison_test_functions!(ne, i128);
define_gpu_signed_scalar_comparison_test_functions!(lt, i128);
define_gpu_signed_scalar_comparison_test_functions!(le, i128);
define_gpu_signed_scalar_comparison_test_functions!(gt, i128);
define_gpu_signed_scalar_comparison_test_functions!(ge, i128);

@@ -79,6 +79,137 @@ where
    test_default_scalar_minmax(params, 2, executor, std::cmp::max::<U256>);
}

// The goal of this function is to ensure that scalar comparisons
// work when the scalar type used is either bigger or smaller (in bit size)
// than the ciphertext
//fn integer_unchecked_scalar_comparisons_edge(param: ClassicPBSParameters) {
//    let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
//
//    let gpu_index = 0;
//    let device = CudaDevice::new(gpu_index);
//    let stream = CudaStream::new_unchecked(device);
//
//    let (cks, sks) = gen_keys_gpu(param, &stream);
//
//    let mut rng = rand::thread_rng();
//
//    for _ in 0..4 {
//        let clear_a = rng.gen_range((u128::from(u64::MAX) + 1)..=u128::MAX);
//        let smaller_clear = rng.gen::<u64>();
//        let bigger_clear = rng.gen::<U256>();
//
//        let a = cks.encrypt_radix(clear_a, num_block);
//        // Copy to the GPU
//        let d_a = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&a, &stream);
//
//        // >=
//        {
//            let d_result = sks.unchecked_scalar_ge(&d_a, smaller_clear, &stream);
//            let result = d_result.to_boolean_block(&stream);
//            let decrypted = cks.decrypt_bool(&result);
//            assert_eq!(decrypted, U256::from(clear_a) >= U256::from(smaller_clear));
//
//            let d_result = sks.unchecked_scalar_ge(&d_a, bigger_clear, &stream);
//            let result = d_result.to_boolean_block(&stream);
//            let decrypted = cks.decrypt_bool(&result);
//            assert_eq!(decrypted, U256::from(clear_a) >= bigger_clear);
//        }
//
//        // >
//        {
//            let d_result = sks.unchecked_scalar_gt(&d_a, smaller_clear, &stream);
//            let result = d_result.to_boolean_block(&stream);
//            let decrypted = cks.decrypt_bool(&result);
//            assert_eq!(decrypted, U256::from(clear_a) > U256::from(smaller_clear));
//
//            let d_result = sks.unchecked_scalar_gt(&d_a, bigger_clear, &stream);
//            let result = d_result.to_boolean_block(&stream);
//            let decrypted = cks.decrypt_bool(&result);
//            assert_eq!(decrypted, U256::from(clear_a) > bigger_clear);
//        }
//
//        // <=
//        {
//            let d_result = sks.unchecked_scalar_le(&d_a, smaller_clear, &stream);
//            let result = d_result.to_boolean_block(&stream);
//            let decrypted = cks.decrypt_bool(&result);
//            assert_eq!(decrypted, U256::from(clear_a) <= U256::from(smaller_clear));
//
//            let d_result = sks.unchecked_scalar_le(&d_a, bigger_clear, &stream);
//            let result = d_result.to_boolean_block(&stream);
//            let decrypted = cks.decrypt_bool(&result);
//            assert_eq!(decrypted, U256::from(clear_a) <= bigger_clear);
//        }
//
//        // <
//        {
//            let d_result = sks.unchecked_scalar_lt(&d_a, smaller_clear, &stream);
//            let result = d_result.to_boolean_block(&stream);
//            let decrypted = cks.decrypt_bool(&result);
//            assert_eq!(decrypted, U256::from(clear_a) < U256::from(smaller_clear));
//
//            let d_result = sks.unchecked_scalar_lt(&d_a, bigger_clear, &stream);
//            let result = d_result.to_boolean_block(&stream);
//            let decrypted = cks.decrypt_bool(&result);
//            assert_eq!(decrypted, U256::from(clear_a) < bigger_clear);
//        }
//
//        // ==
//        {
//            let d_result = sks.unchecked_scalar_eq(&d_a, smaller_clear, &stream);
//            let result = d_result.to_boolean_block(&stream);
//            let decrypted = cks.decrypt_bool(&result);
//            assert_eq!(decrypted, U256::from(clear_a) == U256::from(smaller_clear));
//
//            let d_result = sks.unchecked_scalar_eq(&d_a, bigger_clear, &stream);
//            let result = d_result.to_boolean_block(&stream);
//            let decrypted = cks.decrypt_bool(&result);
//            assert_eq!(decrypted, U256::from(clear_a) == bigger_clear);
//        }
//
//        // !=
//        {
//            let d_result = sks.unchecked_scalar_ne(&d_a, smaller_clear, &stream);
//            let result = d_result.to_boolean_block(&stream);
//            let decrypted = cks.decrypt_bool(&result);
//            assert_eq!(decrypted, U256::from(clear_a) != U256::from(smaller_clear));
//
//            let d_result = sks.unchecked_scalar_ne(&d_a, bigger_clear, &stream);
//            let result = d_result.to_boolean_block(&stream);
//            let decrypted = cks.decrypt_bool(&result);
//            assert_eq!(decrypted, U256::from(clear_a) != bigger_clear);
//        }
//
//        // Here the goal is to test the branching
//        // made in the scalar sign function
//        //
//        // We are forcing one of the two branches to work on empty slices
//        {
//            let d_result = sks.unchecked_scalar_lt(&d_a, U256::ZERO, &stream);
//            let result = d_result.to_boolean_block(&stream);
//            let decrypted = cks.decrypt_bool(&result);
//            assert_eq!(decrypted, U256::from(clear_a) < U256::ZERO);
//
//            let d_result = sks.unchecked_scalar_lt(&d_a, U256::MAX, &stream);
//            let result = d_result.to_boolean_block(&stream);
//            let decrypted = cks.decrypt_bool(&result);
//            assert_eq!(decrypted, U256::from(clear_a) < U256::MAX);
//
//            // == (as it does not share same code)
//            let d_result = sks.unchecked_scalar_eq(&d_a, U256::ZERO, &stream);
//            let result = d_result.to_boolean_block(&stream);
//            let decrypted = cks.decrypt_bool(&result);
//            assert_eq!(decrypted, U256::from(clear_a) == U256::ZERO);
//
//            // != (as it does not share same code)
//            let d_result = sks.unchecked_scalar_ne(&d_a, U256::MAX, &stream);
//            let result = d_result.to_boolean_block(&stream);
//            let decrypted = cks.decrypt_bool(&result);
//            assert_eq!(decrypted, U256::from(clear_a) != U256::MAX);
//        }
//    }
//}

create_gpu_parametrized_test!(integer_unchecked_scalar_min_u256);
create_gpu_parametrized_test!(integer_unchecked_scalar_max_u256);
create_gpu_parametrized_test!(integer_scalar_min_u256);
@@ -90,3 +221,8 @@ define_gpu_scalar_comparison_test_functions!(lt, U256);
define_gpu_scalar_comparison_test_functions!(le, U256);
define_gpu_scalar_comparison_test_functions!(gt, U256);
define_gpu_scalar_comparison_test_functions!(ge, U256);

//create_gpu_parametrized_test!(integer_unchecked_scalar_comparisons_edge {
//    PARAM_MESSAGE_2_CARRY_2_KS_PBS,
//});
//