feat(gpu): implement signed scalar ge, gt, le, lt, max, and min

Pedro Alves
2024-03-29 16:58:32 +01:00
committed by Agnès Leroy
parent 5df40597c2
commit 9576c5fd77
17 changed files with 1214 additions and 245 deletions

View File

@@ -36,7 +36,7 @@ enum COMPARISON_TYPE {
MAX = 6,
MIN = 7,
};
enum IS_RELATIONSHIP { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };
enum CMP_ORDERING { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };
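The rename makes the three-valued per-block comparison result explicit. As a point of reference, a clear-value analogue of the per-block comparison that produces these ordering values (standalone C++ sketch with a hypothetical compare_block helper, not part of the kernels):

#include <cassert>
#include <cstdint>

enum CMP_ORDERING : uint64_t { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };

// Clear-value analogue: compare one ciphertext block against the matching
// scalar block and encode the result as an ordering.
static uint64_t compare_block(uint64_t ct_block, uint64_t scalar_block) {
  if (ct_block < scalar_block)
    return IS_INFERIOR;
  if (ct_block == scalar_block)
    return IS_EQUAL;
  return IS_SUPERIOR;
}

int main() {
  assert(compare_block(2, 3) == IS_INFERIOR);
  assert(compare_block(3, 3) == IS_EQUAL);
  assert(compare_block(3, 1) == IS_SUPERIOR);
  return 0;
}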
extern "C" {
void scratch_cuda_full_propagation_64(
@@ -1846,6 +1846,8 @@ template <typename Torus> struct int_tree_sign_reduction_buffer {
bool allocate_gpu_memory) {
this->params = params;
Torus big_size = (params.big_lwe_dimension + 1) * sizeof(Torus);
block_selector_f = [](Torus msb, Torus lsb) -> Torus {
if (msb == IS_EQUAL) // EQUAL
return lsb;
@@ -1854,13 +1856,8 @@ template <typename Torus> struct int_tree_sign_reduction_buffer {
};
if (allocate_gpu_memory) {
tmp_x = (Torus *)cuda_malloc_async((params.big_lwe_dimension + 1) *
num_radix_blocks * sizeof(Torus),
stream);
tmp_y = (Torus *)cuda_malloc_async((params.big_lwe_dimension + 1) *
num_radix_blocks * sizeof(Torus),
stream);
tmp_x = (Torus *)cuda_malloc_async(big_size * num_radix_blocks, stream);
tmp_y = (Torus *)cuda_malloc_async(big_size * num_radix_blocks, stream);
// LUTs
tree_inner_leaf_lut = new int_radix_lut<Torus>(
stream, params, 1, num_radix_blocks, allocate_gpu_memory);
@@ -1901,6 +1898,10 @@ template <typename Torus> struct int_comparison_diff_buffer {
int_tree_sign_reduction_buffer<Torus> *tree_buffer;
Torus *tmp_signs_a;
Torus *tmp_signs_b;
int_radix_lut<Torus> *reduce_signs_lut;
int_comparison_diff_buffer(cuda_stream_t *stream, COMPARISON_TYPE op,
int_radix_params params, uint32_t num_radix_blocks,
bool allocate_gpu_memory) {
@@ -1922,7 +1923,6 @@ template <typename Torus> struct int_comparison_diff_buffer {
return 42;
}
};
if (allocate_gpu_memory) {
Torus big_size = (params.big_lwe_dimension + 1) * sizeof(Torus);
@@ -1935,15 +1935,26 @@ template <typename Torus> struct int_comparison_diff_buffer {
tree_buffer = new int_tree_sign_reduction_buffer<Torus>(
stream, operator_f, params, num_radix_blocks, allocate_gpu_memory);
tmp_signs_a =
(Torus *)cuda_malloc_async(big_size * num_radix_blocks, stream);
tmp_signs_b =
(Torus *)cuda_malloc_async(big_size * num_radix_blocks, stream);
// LUTs
reduce_signs_lut = new int_radix_lut<Torus>(
stream, params, 1, num_radix_blocks, allocate_gpu_memory);
}
}
void release(cuda_stream_t *stream) {
tree_buffer->release(stream);
delete tree_buffer;
reduce_signs_lut->release(stream);
delete reduce_signs_lut;
cuda_drop_async(tmp_packed_left, stream);
cuda_drop_async(tmp_packed_right, stream);
cuda_drop_async(tmp_signs_a, stream);
cuda_drop_async(tmp_signs_b, stream);
}
};
@@ -1963,6 +1974,7 @@ template <typename Torus> struct int_comparison_buffer {
Torus *tmp_block_comparisons;
Torus *tmp_lwe_array_out;
Torus *tmp_trivial_sign_block;
// Scalar EQ / NE
Torus *tmp_packed_input;
@@ -1975,6 +1987,7 @@ template <typename Torus> struct int_comparison_buffer {
bool is_signed;
// Used for scalar comparisons
int_radix_lut<Torus> *signed_msb_lut;
cuda_stream_t *lsb_stream;
cuda_stream_t *msb_stream;
@@ -1987,22 +2000,22 @@ template <typename Torus> struct int_comparison_buffer {
identity_lut_f = [](Torus x) -> Torus { return x; };
auto big_lwe_size = params.big_lwe_dimension + 1;
if (allocate_gpu_memory) {
lsb_stream = cuda_create_stream(stream->gpu_index);
msb_stream = cuda_create_stream(stream->gpu_index);
// +1 to have space for signed comparison
tmp_lwe_array_out = (Torus *)cuda_malloc_async(
(params.big_lwe_dimension + 1) * num_radix_blocks * sizeof(Torus),
stream);
big_lwe_size * (num_radix_blocks + 1) * sizeof(Torus), stream);
tmp_packed_input = (Torus *)cuda_malloc_async(
(params.big_lwe_dimension + 1) * 2 * num_radix_blocks * sizeof(Torus),
stream);
big_lwe_size * 2 * num_radix_blocks * sizeof(Torus), stream);
// Block comparisons
tmp_block_comparisons = (Torus *)cuda_malloc_async(
(params.big_lwe_dimension + 1) * num_radix_blocks * sizeof(Torus),
stream);
big_lwe_size * num_radix_blocks * sizeof(Torus), stream);
// Cleaning LUT
identity_lut = new int_radix_lut<Torus>(
@@ -2054,13 +2067,19 @@ template <typename Torus> struct int_comparison_buffer {
}
if (is_signed) {
tmp_trivial_sign_block =
(Torus *)cuda_malloc_async(big_lwe_size * sizeof(Torus), stream);
signed_lut =
new int_radix_lut<Torus>(stream, params, 1, 1, allocate_gpu_memory);
signed_msb_lut =
new int_radix_lut<Torus>(stream, params, 1, 1, allocate_gpu_memory);
auto message_modulus = (int)params.message_modulus;
uint32_t sign_bit_pos = log2(message_modulus) - 1;
std::function<Torus(Torus, Torus)> signed_lut_f;
signed_lut_f = [sign_bit_pos](Torus x, Torus y) -> Torus {
std::function<Torus(Torus, Torus)> signed_lut_f =
[sign_bit_pos](Torus x, Torus y) -> Torus {
auto x_sign_bit = x >> sign_bit_pos;
auto y_sign_bit = y >> sign_bit_pos;
@@ -2076,14 +2095,14 @@ template <typename Torus> struct int_comparison_buffer {
return (Torus)(IS_INFERIOR);
else if (x == y)
return (Torus)(IS_EQUAL);
else if (x > y)
else
return (Torus)(IS_SUPERIOR);
} else {
if (x < y)
return (Torus)(IS_SUPERIOR);
else if (x == y)
return (Torus)(IS_EQUAL);
else if (x > y)
else
return (Torus)(IS_INFERIOR);
}
PANIC("Cuda error: sign_lut creation failed due to wrong function.")
@@ -2126,8 +2145,11 @@ template <typename Torus> struct int_comparison_buffer {
cuda_drop_async(tmp_packed_input, stream);
if (is_signed) {
cuda_drop_async(tmp_trivial_sign_block, stream);
signed_lut->release(stream);
delete (signed_lut);
signed_msb_lut->release(stream);
delete (signed_msb_lut);
}
cuda_destroy_stream(lsb_stream);
cuda_destroy_stream(msb_stream);

View File

@@ -273,7 +273,7 @@ __host__ void host_compare_with_zero_equality(
remainder_blocks -= (chunk_size - 1);
// Update operands
chunk += chunk_size * big_lwe_size;
chunk += (chunk_size - 1) * big_lwe_size;
sum_i += big_lwe_size;
}
}

View File

@@ -587,7 +587,7 @@ __global__ void device_pack_blocks(Torus *lwe_array_out, Torus *lwe_array_in,
packed_block[tid] = lsb_block[tid] + factor * msb_block[tid];
}
if (num_radix_blocks % 2 != 0) {
if (num_radix_blocks % 2 == 1) {
// We couldn't pack the last block, so we just copy it
Torus *lsb_block =
lwe_array_in + (num_radix_blocks - 1) * (lwe_dimension + 1);
@@ -684,4 +684,91 @@ __host__ void extract_n_bits(cuda_stream_t *stream, Torus *lwe_array_out,
num_radix_blocks * bits_per_block, bit_extract->lut);
}
template <typename Torus>
__host__ void reduce_signs(cuda_stream_t *stream, Torus *signs_array_out,
Torus *signs_array_in,
int_comparison_buffer<Torus> *mem_ptr,
std::function<Torus(Torus)> sign_handler_f,
void *bsk, Torus *ksk, uint32_t num_sign_blocks) {
auto diff_buffer = mem_ptr->diff_buffer;
auto params = mem_ptr->params;
auto big_lwe_dimension = params.big_lwe_dimension;
auto glwe_dimension = params.glwe_dimension;
auto polynomial_size = params.polynomial_size;
auto message_modulus = params.message_modulus;
auto carry_modulus = params.carry_modulus;
std::function<Torus(Torus)> reduce_two_orderings_function =
[diff_buffer, sign_handler_f](Torus x) -> Torus {
int msb = (x >> 2) & 3;
int lsb = x & 3;
return diff_buffer->tree_buffer->block_selector_f(msb, lsb);
};
auto signs_a = diff_buffer->tmp_signs_a;
auto signs_b = diff_buffer->tmp_signs_b;
cuda_memcpy_async_gpu_to_gpu(
signs_a, signs_array_in,
(big_lwe_dimension + 1) * num_sign_blocks * sizeof(Torus), stream);
if (num_sign_blocks > 2) {
auto lut = diff_buffer->reduce_signs_lut;
generate_device_accumulator<Torus>(
stream, lut->lut, glwe_dimension, polynomial_size, message_modulus,
carry_modulus, reduce_two_orderings_function);
while (num_sign_blocks > 2) {
pack_blocks(stream, signs_b, signs_a, big_lwe_dimension, num_sign_blocks,
4);
integer_radix_apply_univariate_lookup_table_kb(
stream, signs_a, signs_b, bsk, ksk, num_sign_blocks / 2, lut);
auto last_block_signs_b =
signs_b + (num_sign_blocks / 2) * (big_lwe_dimension + 1);
auto last_block_signs_a =
signs_a + (num_sign_blocks / 2) * (big_lwe_dimension + 1);
if (num_sign_blocks % 2 == 1)
cuda_memcpy_async_gpu_to_gpu(last_block_signs_a, last_block_signs_b,
(big_lwe_dimension + 1) * sizeof(Torus),
stream);
num_sign_blocks = (num_sign_blocks / 2) + (num_sign_blocks % 2);
}
}
if (num_sign_blocks == 2) {
std::function<Torus(Torus)> final_lut_f =
[reduce_two_orderings_function, sign_handler_f](Torus x) -> Torus {
Torus final_sign = reduce_two_orderings_function(x);
return sign_handler_f(final_sign);
};
auto lut = diff_buffer->reduce_signs_lut;
generate_device_accumulator<Torus>(stream, lut->lut, glwe_dimension,
polynomial_size, message_modulus,
carry_modulus, final_lut_f);
pack_blocks(stream, signs_b, signs_a, big_lwe_dimension, 2, 4);
integer_radix_apply_univariate_lookup_table_kb(stream, signs_array_out,
signs_b, bsk, ksk, 1, lut);
} else {
std::function<Torus(Torus)> final_lut_f =
[mem_ptr, sign_handler_f](Torus x) -> Torus {
return sign_handler_f(x & 3);
};
auto lut = mem_ptr->diff_buffer->reduce_signs_lut;
generate_device_accumulator<Torus>(stream, lut->lut, glwe_dimension,
polynomial_size, message_modulus,
carry_modulus, final_lut_f);
integer_radix_apply_univariate_lookup_table_kb(stream, signs_array_out,
signs_a, bsk, ksk, 1, lut);
}
}
#endif // TFHE_RS_INTERNAL_INTEGER_CUH
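For intuition, the reduction above can be traced on clear values: per-block orderings are packed in pairs as lsb + 4 * msb, reduce_two_orderings_function keeps the more significant ordering unless it is IS_EQUAL, and the loop repeats until one or two blocks remain (the final sign_handler_f application is omitted here). A standalone C++ sketch with hypothetical block_selector and reduce_signs_clear helpers, not the CUDA code path:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

enum CMP_ORDERING : uint64_t { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };

// Same semantics as block_selector_f: the more significant ordering wins
// unless it is IS_EQUAL, in which case the less significant one decides.
static uint64_t block_selector(uint64_t msb, uint64_t lsb) {
  return msb == IS_EQUAL ? lsb : msb;
}

// Reduce per-block orderings (least significant block first, as in the radix
// layout) into a single ordering, packing pairs as lsb + 4 * msb.
static uint64_t reduce_signs_clear(std::vector<uint64_t> signs) {
  while (signs.size() > 1) {
    std::vector<uint64_t> next;
    size_t i = 0;
    for (; i + 1 < signs.size(); i += 2) {
      uint64_t packed = signs[i] + 4 * signs[i + 1];
      next.push_back(block_selector((packed >> 2) & 3, packed & 3));
    }
    if (i < signs.size()) // odd block left over, carried to the next round
      next.push_back(signs[i]);
    signs = std::move(next);
  }
  return signs[0];
}

int main() {
  // The low blocks compare equal and the most significant block compares
  // inferior, so the reduced result is IS_INFERIOR (prints 0).
  uint64_t r = reduce_signs_clear({IS_EQUAL, IS_EQUAL, IS_EQUAL, IS_INFERIOR});
  std::printf("%llu\n", (unsigned long long)r);
  return 0;
}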

View File

@@ -5,7 +5,7 @@
#include <omp.h>
template <typename Torus>
__host__ void host_integer_radix_scalar_difference_check_kb(
__host__ void integer_radix_unsigned_scalar_difference_check_kb(
cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in,
Torus *scalar_blocks, int_comparison_buffer<Torus> *mem_ptr,
std::function<Torus(Torus)> sign_handler_f, void *bsk, Torus *ksk,
@@ -184,6 +184,344 @@ __host__ void host_integer_radix_scalar_difference_check_kb(
}
}
template <typename Torus>
__host__ void integer_radix_signed_scalar_difference_check_kb(
cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in,
Torus *scalar_blocks, int_comparison_buffer<Torus> *mem_ptr,
std::function<Torus(Torus)> sign_handler_f, void *bsk, Torus *ksk,
uint32_t total_num_radix_blocks, uint32_t total_num_scalar_blocks) {
cudaSetDevice(stream->gpu_index);
auto params = mem_ptr->params;
auto big_lwe_dimension = params.big_lwe_dimension;
auto glwe_dimension = params.glwe_dimension;
auto polynomial_size = params.polynomial_size;
auto message_modulus = params.message_modulus;
auto carry_modulus = params.carry_modulus;
auto diff_buffer = mem_ptr->diff_buffer;
size_t big_lwe_size = big_lwe_dimension + 1;
// Reducing the signs is the bottleneck of the comparison algorithms,
// however in the scalar case there is an improvement:
//
// The idea is to reduce the number of sign blocks we have to
// reduce. We can do that by splitting the comparison problem into two parts.
//
// - One part computes the sign blocks between the scalar and just enough
//   blocks from the ciphertext to represent the scalar value
//
// - The other part compares the ciphertext blocks not considered for the
//   sign computation with zero, and creates a single sign block from that.
//
// The smaller the scalar value is compared to the number of bits encrypted
// in the ciphertext, the more comparisons with zero we have to do, and the
// fewer sign blocks we will have to reduce.
//
// This creates a speedup, as comparing a bunch of blocks with 0 is faster.
if (total_num_scalar_blocks == 0) {
// We only have to compare blocks with zero,
// which means the scalar is zero
Torus *are_all_msb_zeros = mem_ptr->tmp_lwe_array_out;
host_compare_with_zero_equality(stream, are_all_msb_zeros, lwe_array_in,
mem_ptr, bsk, ksk, total_num_radix_blocks,
mem_ptr->is_zero_lut);
Torus *sign_block =
lwe_array_in + (total_num_radix_blocks - 1) * big_lwe_size;
auto sign_bit_pos = (int)std::log2(message_modulus) - 1;
auto scalar_last_leaf_with_respect_to_zero_lut_f =
[sign_handler_f, sign_bit_pos,
message_modulus](Torus sign_block) -> Torus {
sign_block %= message_modulus;
int sign_bit_is_set = (sign_block >> sign_bit_pos) == 1;
CMP_ORDERING sign_block_ordering;
if (sign_bit_is_set) {
sign_block_ordering = CMP_ORDERING::IS_INFERIOR;
} else if (sign_block != 0) {
sign_block_ordering = CMP_ORDERING::IS_SUPERIOR;
} else {
sign_block_ordering = CMP_ORDERING::IS_EQUAL;
}
return sign_block_ordering;
};
auto block_selector_f = mem_ptr->diff_buffer->tree_buffer->block_selector_f;
auto scalar_bivariate_last_leaf_lut_f =
[scalar_last_leaf_with_respect_to_zero_lut_f, sign_handler_f,
block_selector_f](Torus are_all_zeros, Torus sign_block) -> Torus {
// "re-code" are_all_zeros as an ordering value
if (are_all_zeros == 1) {
are_all_zeros = CMP_ORDERING::IS_EQUAL;
} else {
are_all_zeros = CMP_ORDERING::IS_SUPERIOR;
};
return sign_handler_f(block_selector_f(
scalar_last_leaf_with_respect_to_zero_lut_f(sign_block),
are_all_zeros));
};
auto lut = mem_ptr->diff_buffer->tree_buffer->tree_last_leaf_scalar_lut;
generate_device_accumulator_bivariate<Torus>(
stream, lut->lut, glwe_dimension, polynomial_size, message_modulus,
carry_modulus, scalar_bivariate_last_leaf_lut_f);
integer_radix_apply_bivariate_lookup_table_kb(
stream, lwe_array_out, are_all_msb_zeros, sign_block, bsk, ksk, 1, lut);
} else if (total_num_scalar_blocks < total_num_radix_blocks) {
// We have to handle both parts of the work described above,
// and the sign bit is located in the most significant blocks
uint32_t num_lsb_radix_blocks = total_num_scalar_blocks;
uint32_t num_msb_radix_blocks =
total_num_radix_blocks - num_lsb_radix_blocks;
auto msb = lwe_array_in + num_lsb_radix_blocks * big_lwe_size;
auto lwe_array_lsb_out = mem_ptr->tmp_lwe_array_out;
auto lwe_array_msb_out = lwe_array_lsb_out + big_lwe_size;
cuda_synchronize_stream(stream);
auto lsb_stream = mem_ptr->lsb_stream;
auto msb_stream = mem_ptr->msb_stream;
#pragma omp parallel sections
{
// Both sections may be executed in parallel
#pragma omp section
{
//////////////
// lsb
Torus *lhs = diff_buffer->tmp_packed_left;
Torus *rhs = diff_buffer->tmp_packed_right;
pack_blocks(lsb_stream, lhs, lwe_array_in, big_lwe_dimension,
num_lsb_radix_blocks, message_modulus);
pack_blocks(lsb_stream, rhs, scalar_blocks, 0, total_num_scalar_blocks,
message_modulus);
// From this point we have half number of blocks
num_lsb_radix_blocks /= 2;
num_lsb_radix_blocks += (total_num_scalar_blocks % 2);
// comparisons will be assigned
// - 0 if lhs < rhs
// - 1 if lhs == rhs
// - 2 if lhs > rhs
auto comparisons = mem_ptr->tmp_block_comparisons;
scalar_compare_radix_blocks_kb(lsb_stream, comparisons, lhs, rhs,
mem_ptr, bsk, ksk, num_lsb_radix_blocks);
// Reduces a vec containing radix blocks that each encrypt a sign
// (inferior, equal, superior) to one single radix block containing the
// final sign
tree_sign_reduction(lsb_stream, lwe_array_lsb_out, comparisons,
mem_ptr->diff_buffer->tree_buffer,
mem_ptr->identity_lut_f, bsk, ksk,
num_lsb_radix_blocks);
}
#pragma omp section
{
//////////////
// msb
// We remove the last block (which is the sign)
Torus *are_all_msb_zeros = lwe_array_msb_out;
host_compare_with_zero_equality(msb_stream, are_all_msb_zeros, msb,
mem_ptr, bsk, ksk, num_msb_radix_blocks,
mem_ptr->is_zero_lut);
auto sign_bit_pos = (int)log2(message_modulus) - 1;
auto lut_f = [mem_ptr, sign_bit_pos](Torus sign_block,
Torus msb_are_zeros) {
bool sign_bit_is_set = (sign_block >> sign_bit_pos) == 1;
CMP_ORDERING sign_block_ordering;
if (sign_bit_is_set) {
sign_block_ordering = CMP_ORDERING::IS_INFERIOR;
} else if (sign_block != 0) {
sign_block_ordering = CMP_ORDERING::IS_SUPERIOR;
} else {
sign_block_ordering = CMP_ORDERING::IS_EQUAL;
}
CMP_ORDERING msb_ordering;
if (msb_are_zeros == 1)
msb_ordering = CMP_ORDERING::IS_EQUAL;
else
msb_ordering = CMP_ORDERING::IS_SUPERIOR;
return mem_ptr->diff_buffer->tree_buffer->block_selector_f(
sign_block_ordering, msb_ordering);
};
auto signed_msb_lut = mem_ptr->signed_msb_lut;
generate_device_accumulator_bivariate<Torus>(
msb_stream, signed_msb_lut->lut, params.glwe_dimension,
params.polynomial_size, params.message_modulus,
params.carry_modulus, lut_f);
Torus *sign_block = msb + (num_msb_radix_blocks - 1) * big_lwe_size;
integer_radix_apply_bivariate_lookup_table_kb(
msb_stream, lwe_array_msb_out, sign_block, are_all_msb_zeros, bsk,
ksk, 1, signed_msb_lut);
}
}
cuda_synchronize_stream(lsb_stream);
cuda_synchronize_stream(msb_stream);
//////////////
// Reduce the two blocks into one final sign block
reduce_signs(stream, lwe_array_out, lwe_array_lsb_out, mem_ptr,
sign_handler_f, bsk, ksk, 2);
} else {
// We only have to do the regular comparison
// And not the part where we compare most significant blocks with zeros
// total_num_radix_blocks == total_num_scalar_blocks
uint32_t num_lsb_radix_blocks = total_num_radix_blocks;
cuda_synchronize_stream(stream);
auto lsb_stream = mem_ptr->lsb_stream;
auto msb_stream = mem_ptr->msb_stream;
auto lwe_array_ct_out = mem_ptr->tmp_lwe_array_out;
auto lwe_array_sign_out =
lwe_array_ct_out + (num_lsb_radix_blocks / 2) * big_lwe_size;
#pragma omp parallel sections
{
// Both sections may be executed in parallel
#pragma omp section
{
Torus *lhs = diff_buffer->tmp_packed_left;
Torus *rhs = diff_buffer->tmp_packed_right;
pack_blocks(lsb_stream, lhs, lwe_array_in, big_lwe_dimension,
num_lsb_radix_blocks - 1, message_modulus);
pack_blocks(lsb_stream, rhs, scalar_blocks, 0, num_lsb_radix_blocks - 1,
message_modulus);
// From this point we have half number of blocks
num_lsb_radix_blocks /= 2;
// comparisons will be assigned
// - 0 if lhs < rhs
// - 1 if lhs == rhs
// - 2 if lhs > rhs
scalar_compare_radix_blocks_kb(lsb_stream, lwe_array_ct_out, lhs, rhs,
mem_ptr, bsk, ksk, num_lsb_radix_blocks);
}
#pragma omp section
{
Torus *encrypted_sign_block =
lwe_array_in + (total_num_radix_blocks - 1) * big_lwe_size;
Torus *scalar_sign_block =
scalar_blocks + (total_num_scalar_blocks - 1);
auto trivial_sign_block = mem_ptr->tmp_trivial_sign_block;
create_trivial_radix(msb_stream, trivial_sign_block, scalar_sign_block,
big_lwe_dimension, 1, 1, message_modulus,
carry_modulus);
integer_radix_apply_bivariate_lookup_table_kb(
msb_stream, lwe_array_sign_out, encrypted_sign_block,
trivial_sign_block, bsk, ksk, 1, mem_ptr->signed_lut);
}
}
cuda_synchronize_stream(lsb_stream);
cuda_synchronize_stream(msb_stream);
// Reduces a vec containing radix blocks that each encrypt a sign
// (inferior, equal, superior) to one single radix block containing the
// final sign
reduce_signs(stream, lwe_array_out, lwe_array_ct_out, mem_ptr,
sign_handler_f, bsk, ksk, num_lsb_radix_blocks + 1);
}
}
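The function above splits the signed comparison into a low part checked against the scalar blocks and a high part checked against zero and combined with the sign block. A clear-value sketch of that split for the middle case (non-negative scalar, fewer scalar blocks than ciphertext blocks), written as standalone C++ with a hypothetical signed_scalar_diff_clear helper; the sign_handler_f step is left out:

#include <cstddef>
#include <cstdint>
#include <vector>

enum CMP_ORDERING : uint64_t { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };

static uint64_t selector(uint64_t msb, uint64_t lsb) {
  return msb == IS_EQUAL ? lsb : msb;
}

// Clear-value analogue for 0 < num_scalar_blocks < num_ct_blocks with a
// non-negative scalar. Blocks are stored least significant first.
static uint64_t signed_scalar_diff_clear(const std::vector<uint64_t> &ct_blocks,
                                         const std::vector<uint64_t> &scalar_blocks,
                                         uint64_t message_modulus) {
  // Low part: block-wise comparison against the scalar, the most significant
  // differing block decides.
  uint64_t lsb_ord = IS_EQUAL;
  for (size_t i = 0; i < scalar_blocks.size(); ++i) {
    uint64_t o = ct_blocks[i] < scalar_blocks[i]    ? IS_INFERIOR
                 : ct_blocks[i] == scalar_blocks[i] ? IS_EQUAL
                                                    : IS_SUPERIOR;
    lsb_ord = selector(o, lsb_ord);
  }
  // High part: remaining blocks compared with zero; the top block carries the
  // sign bit of the two's-complement encoding.
  uint64_t sign_bit_pos = 0;
  for (uint64_t m = message_modulus; m > 2; m >>= 1)
    ++sign_bit_pos; // log2(message_modulus) - 1
  uint64_t sign_block = ct_blocks.back();
  bool sign_bit_is_set = ((sign_block >> sign_bit_pos) & 1) != 0;
  bool msb_all_zero = true;
  for (size_t i = scalar_blocks.size(); i < ct_blocks.size(); ++i)
    msb_all_zero = msb_all_zero && (ct_blocks[i] == 0);
  uint64_t sign_ord =
      sign_bit_is_set ? IS_INFERIOR : (sign_block != 0 ? IS_SUPERIOR : IS_EQUAL);
  uint64_t zeros_ord = msb_all_zero ? IS_EQUAL : IS_SUPERIOR;
  uint64_t msb_ord = selector(sign_ord, zeros_ord);
  // Final reduction of the two partial orderings: the high part wins unless
  // it is IS_EQUAL.
  return selector(msb_ord, lsb_ord);
}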
template <typename Torus>
__host__ void integer_radix_signed_scalar_maxmin_kb(
cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in,
Torus *scalar_blocks, int_comparison_buffer<Torus> *mem_ptr, void *bsk,
Torus *ksk, uint32_t total_num_radix_blocks,
uint32_t total_num_scalar_blocks) {
cudaSetDevice(stream->gpu_index);
auto params = mem_ptr->params;
// Calculates the difference sign between the ciphertext and the scalar
// - 0 if lhs < rhs
// - 1 if lhs == rhs
// - 2 if lhs > rhs
auto sign = mem_ptr->tmp_lwe_array_out;
integer_radix_signed_scalar_difference_check_kb(
stream, sign, lwe_array_in, scalar_blocks, mem_ptr,
mem_ptr->identity_lut_f, bsk, ksk, total_num_radix_blocks,
total_num_scalar_blocks);
// There is no optimized CMUX for scalars, so we convert to a trivial
// ciphertext
auto lwe_array_left = lwe_array_in;
auto lwe_array_right = mem_ptr->tmp_block_comparisons;
create_trivial_radix(stream, lwe_array_right, scalar_blocks,
params.big_lwe_dimension, total_num_radix_blocks,
total_num_scalar_blocks, params.message_modulus,
params.carry_modulus);
// Selector
// CMUX for Max or Min
host_integer_radix_cmux_kb(stream, lwe_array_out, sign, lwe_array_left,
lwe_array_right, mem_ptr->cmux_buffer, bsk, ksk,
total_num_radix_blocks);
}
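The max/min variant reuses the ordering produced by the difference check and then selects between the ciphertext and the trivially encrypted scalar via the CMUX. A clear-value sketch of the selection rule (standalone C++, hypothetical select_max/select_min helpers; the actual CMUX LUTs are not reproduced):

#include <cstdint>

enum CMP_ORDERING : uint64_t { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };

// sign encodes the ordering of lhs (the ciphertext) with respect to rhs (the
// scalar); on IS_EQUAL either operand may be returned since the values match.
static uint64_t select_max(uint64_t sign, uint64_t lhs, uint64_t rhs) {
  return sign == IS_INFERIOR ? rhs : lhs;
}

static uint64_t select_min(uint64_t sign, uint64_t lhs, uint64_t rhs) {
  return sign == IS_SUPERIOR ? rhs : lhs;
}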
template <typename Torus>
__host__ void host_integer_radix_scalar_difference_check_kb(
cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in,
Torus *scalar_blocks, int_comparison_buffer<Torus> *mem_ptr,
std::function<Torus(Torus)> sign_handler_f, void *bsk, Torus *ksk,
uint32_t total_num_radix_blocks, uint32_t total_num_scalar_blocks) {
if (mem_ptr->is_signed) {
// is signed and scalar is positive
integer_radix_signed_scalar_difference_check_kb(
stream, lwe_array_out, lwe_array_in, scalar_blocks, mem_ptr,
sign_handler_f, bsk, ksk, total_num_radix_blocks,
total_num_scalar_blocks);
} else {
integer_radix_unsigned_scalar_difference_check_kb(
stream, lwe_array_out, lwe_array_in, scalar_blocks, mem_ptr,
sign_handler_f, bsk, ksk, total_num_radix_blocks,
total_num_scalar_blocks);
}
}
template <typename Torus>
__host__ void host_integer_radix_signed_scalar_maxmin_kb(
cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in,
Torus *scalar_blocks, int_comparison_buffer<Torus> *mem_ptr, void *bsk,
Torus *ksk, uint32_t total_num_radix_blocks,
uint32_t total_num_scalar_blocks) {
if (mem_ptr->is_signed) {
// is signed and scalar is positive
integer_radix_signed_scalar_maxmin_kb(
stream, lwe_array_out, lwe_array_in, scalar_blocks, mem_ptr, bsk, ksk,
total_num_radix_blocks, total_num_scalar_blocks);
} else {
integer_radix_unsigned_scalar_maxmin_kb(
stream, lwe_array_out, lwe_array_in, scalar_blocks, mem_ptr, bsk, ksk,
total_num_radix_blocks, total_num_scalar_blocks);
}
}
template <typename Torus>
__host__ void
scalar_compare_radix_blocks_kb(cuda_stream_t *stream, Torus *lwe_array_out,

View File

@@ -1921,6 +1921,8 @@ mod cuda {
cuda_unchecked_scalar_right_shift,
cuda_unchecked_scalar_rotate_left,
cuda_unchecked_scalar_rotate_right,
cuda_unchecked_scalar_eq,
cuda_unchecked_scalar_ne,
cuda_unchecked_scalar_ge,
cuda_unchecked_scalar_gt,
cuda_unchecked_scalar_le,
@@ -1964,6 +1966,8 @@ mod cuda {
cuda_scalar_bitand,
cuda_scalar_bitor,
cuda_scalar_bitxor,
cuda_scalar_eq,
cuda_scalar_ne,
cuda_scalar_ge,
cuda_scalar_gt,
cuda_scalar_le,

View File

@@ -1795,6 +1795,54 @@ mod cuda {
rng_func: shift_scalar
);
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: unchecked_scalar_eq,
display_name: eq,
rng_func: default_signed_scalar
);
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: unchecked_scalar_ne,
display_name: ne,
rng_func: default_signed_scalar
);
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: unchecked_scalar_gt,
display_name: gt,
rng_func: default_signed_scalar
);
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: unchecked_scalar_ge,
display_name: ge,
rng_func: default_signed_scalar
);
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: unchecked_scalar_lt,
display_name: lt,
rng_func: default_signed_scalar
);
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: unchecked_scalar_le,
display_name: le,
rng_func: default_signed_scalar
);
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: unchecked_scalar_min,
display_name: min,
rng_func: default_signed_scalar
);
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: unchecked_scalar_max,
display_name: max,
rng_func: default_signed_scalar
);
//===========================================
// Default
//===========================================
@@ -1959,6 +2007,54 @@ mod cuda {
rng_func: shift_scalar
);
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: scalar_eq,
display_name: eq,
rng_func: default_signed_scalar
);
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: scalar_ne,
display_name: ne,
rng_func: default_signed_scalar
);
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: scalar_gt,
display_name: gt,
rng_func: default_signed_scalar
);
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: scalar_ge,
display_name: ge,
rng_func: default_signed_scalar
);
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: scalar_lt,
display_name: lt,
rng_func: default_signed_scalar
);
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: scalar_le,
display_name: le,
rng_func: default_signed_scalar
);
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: scalar_min,
display_name: min,
rng_func: default_signed_scalar
);
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: scalar_max,
display_name: max,
rng_func: default_signed_scalar
);
criterion_group!(
unchecked_cuda_ops,
cuda_unchecked_add,
@@ -1995,6 +2091,14 @@ mod cuda {
cuda_unchecked_scalar_right_shift,
cuda_unchecked_scalar_rotate_left,
cuda_unchecked_scalar_rotate_right,
cuda_unchecked_scalar_eq,
cuda_unchecked_scalar_ne,
cuda_unchecked_scalar_gt,
cuda_unchecked_scalar_ge,
cuda_unchecked_scalar_lt,
cuda_unchecked_scalar_le,
cuda_unchecked_scalar_min,
cuda_unchecked_scalar_max,
);
criterion_group!(
@@ -2034,6 +2138,14 @@ mod cuda {
cuda_scalar_right_shift,
cuda_scalar_rotate_left,
cuda_scalar_rotate_right,
cuda_scalar_eq,
cuda_scalar_ne,
cuda_scalar_gt,
cuda_scalar_ge,
cuda_scalar_lt,
cuda_scalar_le,
cuda_scalar_min,
cuda_scalar_max,
);
}

View File

@@ -337,7 +337,7 @@ impl FheEq<bool> for FheBool {
let inner =
cuda_key
.key
.scalar_eq(&self.ciphertext.on_gpu(), u8::from(other), stream);
.scalar_eq(&*self.ciphertext.on_gpu(), u8::from(other), stream);
InnerBoolean::Cuda(inner)
}),
});
@@ -376,7 +376,7 @@ impl FheEq<bool> for FheBool {
let inner =
cuda_key
.key
.scalar_ne(&self.ciphertext.on_gpu(), u8::from(other), stream);
.scalar_ne(&*self.ciphertext.on_gpu(), u8::from(other), stream);
InnerBoolean::Cuda(inner)
}),
});

View File

@@ -6,6 +6,8 @@ use crate::high_level_api::global_state::with_thread_local_cuda_stream;
use crate::high_level_api::keys::InternalServerKey;
#[cfg(feature = "gpu")]
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
#[cfg(feature = "gpu")]
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::BooleanBlock;
use crate::prelude::{FheDecrypt, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt};
use crate::shortint::ciphertext::Degree;
@@ -104,11 +106,12 @@ impl FheTryTrivialEncrypt<bool> for FheBool {
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
let inner = cuda_key
.key
.create_trivial_radix(u64::from(value), 1, stream);
let inner: CudaUnsignedRadixCiphertext =
cuda_key
.key
.create_trivial_radix(u64::from(value), 1, stream);
InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext(
inner.ciphertext,
inner.into_inner(),
))
}),
});

View File

@@ -5,6 +5,8 @@ use crate::high_level_api::global_state::with_thread_local_cuda_stream;
use crate::high_level_api::integers::FheUintId;
use crate::high_level_api::keys::InternalServerKey;
use crate::integer::block_decomposition::{DecomposableInto, RecomposableFrom};
#[cfg(feature = "gpu")]
use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use crate::prelude::{FheDecrypt, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt};
use crate::{ClientKey, CompactPublicKey, CompressedPublicKey, FheUint, PublicKey};
@@ -133,7 +135,7 @@ where
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
let inner = cuda_key.key.create_trivial_radix(
let inner: CudaUnsignedRadixCiphertext = cuda_key.key.create_trivial_radix(
value,
Id::num_blocks(cuda_key.key.message_modulus),
stream,

View File

@@ -17,6 +17,8 @@ use crate::high_level_api::traits::{
};
use crate::integer::block_decomposition::DecomposableInto;
use crate::integer::ciphertext::IntegerCiphertext;
#[cfg(feature = "gpu")]
use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use crate::integer::U256;
use crate::FheBool;
use std::ops::{
@@ -60,7 +62,7 @@ where
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
let inner_result = cuda_key
.key
.scalar_eq(&self.ciphertext.on_gpu(), rhs, stream);
.scalar_eq(&*self.ciphertext.on_gpu(), rhs, stream);
FheBool::new(inner_result)
}),
})
@@ -97,7 +99,7 @@ where
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
let inner_result = cuda_key
.key
.scalar_ne(&self.ciphertext.on_gpu(), rhs, stream);
.scalar_ne(&*self.ciphertext.on_gpu(), rhs, stream);
FheBool::new(inner_result)
}),
})
@@ -140,7 +142,7 @@ where
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
let inner_result = cuda_key
.key
.scalar_lt(&self.ciphertext.on_gpu(), rhs, stream);
.scalar_lt(&*self.ciphertext.on_gpu(), rhs, stream);
FheBool::new(inner_result)
}),
})
@@ -177,7 +179,7 @@ where
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
let inner_result = cuda_key
.key
.scalar_le(&self.ciphertext.on_gpu(), rhs, stream);
.scalar_le(&*self.ciphertext.on_gpu(), rhs, stream);
FheBool::new(inner_result)
}),
})
@@ -214,7 +216,7 @@ where
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
let inner_result = cuda_key
.key
.scalar_gt(&self.ciphertext.on_gpu(), rhs, stream);
.scalar_gt(&*self.ciphertext.on_gpu(), rhs, stream);
FheBool::new(inner_result)
}),
})
@@ -251,7 +253,7 @@ where
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
let inner_result = cuda_key
.key
.scalar_ge(&self.ciphertext.on_gpu(), rhs, stream);
.scalar_ge(&*self.ciphertext.on_gpu(), rhs, stream);
FheBool::new(inner_result)
}),
})
@@ -296,7 +298,7 @@ where
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
let inner_result = cuda_key
.key
.scalar_max(&self.ciphertext.on_gpu(), rhs, stream);
.scalar_max(&*self.ciphertext.on_gpu(), rhs, stream);
Self::new(inner_result)
}),
})
@@ -341,7 +343,7 @@ where
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
let inner_result = cuda_key
.key
.scalar_min(&self.ciphertext.on_gpu(), rhs, stream);
.scalar_min(&*self.ciphertext.on_gpu(), rhs, stream);
Self::new(inner_result)
}),
})
@@ -1036,7 +1038,8 @@ generic_integer_impl_scalar_left_operation!(
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
with_thread_local_cuda_stream(|stream| {
let mut result = cuda_key.key.create_trivial_radix(lhs, rhs.ciphertext.on_gpu().ciphertext.info.blocks.len(), stream);
let mut result: CudaUnsignedRadixCiphertext = cuda_key.key.create_trivial_radix(
lhs, rhs.ciphertext.on_gpu().ciphertext.info.blocks.len(), stream);
cuda_key.key.sub_assign(&mut result, &rhs.ciphertext.on_gpu(), stream);
RadixCiphertext::Cuda(result)
})

View File

@@ -19,6 +19,8 @@ pub trait CudaIntegerRadixCiphertext: Sized {
Self::from(self.as_ref().duplicate(stream))
}
fn into_inner(self) -> CudaRadixCiphertext;
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
@@ -67,6 +69,10 @@ impl CudaIntegerRadixCiphertext for CudaUnsignedRadixCiphertext {
fn from(ct: CudaRadixCiphertext) -> Self {
Self { ciphertext: ct }
}
fn into_inner(self) -> CudaRadixCiphertext {
self.ciphertext
}
}
impl CudaIntegerRadixCiphertext for CudaSignedRadixCiphertext {
@@ -83,6 +89,10 @@ impl CudaIntegerRadixCiphertext for CudaSignedRadixCiphertext {
fn from(ct: CudaRadixCiphertext) -> Self {
Self { ciphertext: ct }
}
fn into_inner(self) -> CudaRadixCiphertext {
self.ciphertext
}
}
impl CudaUnsignedRadixCiphertext {

View File

@@ -1142,7 +1142,7 @@ impl CudaStream {
num_blocks: u32,
num_scalar_blocks: u32,
op: ComparisonType,
is_signed: bool,
signed_with_positive_scalar: bool,
) {
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
scratch_cuda_integer_radix_comparison_kb_64(
@@ -1162,7 +1162,7 @@ impl CudaStream {
carry_modulus.0 as u32,
PBSType::Classical as u32,
op as u32,
is_signed,
signed_with_positive_scalar,
true,
);
@@ -1209,7 +1209,7 @@ impl CudaStream {
num_blocks: u32,
num_scalar_blocks: u32,
op: ComparisonType,
is_signed: bool,
signed_with_positive_scalar: bool,
) {
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
scratch_cuda_integer_radix_comparison_kb_64(
@@ -1229,7 +1229,7 @@ impl CudaStream {
carry_modulus.0 as u32,
PBSType::MultiBit as u32,
op as u32,
is_signed,
signed_with_positive_scalar,
true,
);
cuda_scalar_comparison_integer_radix_ciphertext_kb_64(

View File

@@ -71,7 +71,7 @@ impl CudaServerKey {
num_blocks: usize,
stream: &CudaStream,
) -> T {
T::from(self.create_trivial_radix(0, num_blocks, stream).ciphertext)
self.create_trivial_radix(0, num_blocks, stream)
}
/// Create a trivial ciphertext on the GPU
@@ -80,6 +80,7 @@ impl CudaServerKey {
///
/// ```rust
/// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
/// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// use tfhe::integer::{gen_keys_radix, RadixCiphertext};
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
@@ -93,20 +94,22 @@ impl CudaServerKey {
/// // Generate the client key and the server key:
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &mut stream);
///
/// let d_ctxt = sks.create_trivial_radix(212u64, num_blocks, &mut stream);
/// let d_ctxt: CudaUnsignedRadixCiphertext =
/// sks.create_trivial_radix(212u64, num_blocks, &mut stream);
/// let ctxt = d_ctxt.to_radix_ciphertext(&mut stream);
///
/// // Decrypt:
/// let dec: u64 = cks.decrypt(&ctxt);
/// assert_eq!(212, dec);
/// ```
pub fn create_trivial_radix<Scalar>(
pub fn create_trivial_radix<Scalar, T>(
&self,
scalar: Scalar,
num_blocks: usize,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext
) -> T
where
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
{
let lwe_size = match self.pbs_order {
@@ -140,12 +143,10 @@ impl CudaServerKey {
let d_blocks = CudaLweCiphertextList::from_lwe_ciphertext_list(&cpu_lwe_list, stream);
CudaUnsignedRadixCiphertext {
ciphertext: CudaRadixCiphertext {
d_blocks,
info: CudaRadixCiphertextInfo { blocks: info },
},
}
T::from(CudaRadixCiphertext {
d_blocks,
info: CudaRadixCiphertextInfo { blocks: info },
})
}
/// # Safety
@@ -268,7 +269,7 @@ impl CudaServerKey {
///
///```rust
/// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
/// use tfhe::integer::gpu::ciphertext::CudaRadixCiphertext;
/// use tfhe::integer::gpu::ciphertext::{CudaRadixCiphertext, CudaUnsignedRadixCiphertext};
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// use tfhe::integer::IntegerCiphertext;
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
@@ -282,7 +283,8 @@ impl CudaServerKey {
/// // Generate the client key and the server key:
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &mut stream);
///
/// let mut d_ct1 = sks.create_trivial_radix(7u64, num_blocks, &mut stream);
/// let mut d_ct1: CudaUnsignedRadixCiphertext =
/// sks.create_trivial_radix(7u64, num_blocks, &mut stream);
/// let ct1 = d_ct1.to_radix_ciphertext(&mut stream);
/// assert_eq!(ct1.blocks().len(), 4);
///
@@ -342,6 +344,7 @@ impl CudaServerKey {
///
///```rust
/// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
/// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// use tfhe::integer::IntegerCiphertext;
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
@@ -355,7 +358,8 @@ impl CudaServerKey {
/// // Generate the client key and the server key:
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &mut stream);
///
/// let mut d_ct1 = sks.create_trivial_radix(7u64, num_blocks, &mut stream);
/// let mut d_ct1: CudaUnsignedRadixCiphertext =
/// sks.create_trivial_radix(7u64, num_blocks, &mut stream);
/// let ct1 = d_ct1.to_radix_ciphertext(&mut stream);
/// assert_eq!(ct1.blocks().len(), 4);
///
@@ -406,6 +410,7 @@ impl CudaServerKey {
///
///```rust
/// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
/// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// use tfhe::integer::IntegerCiphertext;
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
@@ -419,7 +424,8 @@ impl CudaServerKey {
/// // Generate the client key and the server key:
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &mut stream);
///
/// let mut d_ct1 = sks.create_trivial_radix(119u64, num_blocks, &mut stream);
/// let mut d_ct1: CudaUnsignedRadixCiphertext =
/// sks.create_trivial_radix(119u64, num_blocks, &mut stream);
/// let ct1 = d_ct1.to_radix_ciphertext(&mut stream);
/// assert_eq!(ct1.blocks().len(), 4);
///
@@ -470,6 +476,7 @@ impl CudaServerKey {
///
///```rust
/// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
/// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// use tfhe::integer::IntegerCiphertext;
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
@@ -483,7 +490,8 @@ impl CudaServerKey {
/// // Generate the client key and the server key:
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &mut stream);
///
/// let mut d_ct1 = sks.create_trivial_radix(119u64, num_blocks, &mut stream);
/// let mut d_ct1: CudaUnsignedRadixCiphertext =
/// sks.create_trivial_radix(119u64, num_blocks, &mut stream);
/// let ct1 = d_ct1.to_radix_ciphertext(&mut stream);
/// assert_eq!(ct1.blocks().len(), 4);
///
@@ -532,6 +540,7 @@ impl CudaServerKey {
///
///```rust
/// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
/// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// use tfhe::integer::IntegerCiphertext;
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
@@ -546,7 +555,8 @@ impl CudaServerKey {
///
/// let msg = 2u8;
///
/// let mut d_ct1 = sks.create_trivial_radix(msg, num_blocks, &mut stream);
/// let mut d_ct1: CudaUnsignedRadixCiphertext =
/// sks.create_trivial_radix(msg, num_blocks, &mut stream);
/// let ct1 = d_ct1.to_radix_ciphertext(&mut stream);
/// assert_eq!(ct1.blocks().len(), 4);
///

View File

@@ -5,36 +5,124 @@ use crate::core_crypto::prelude::{CiphertextModulus, LweCiphertextCount};
use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::info::CudaRadixCiphertextInfo;
use crate::integer::gpu::ciphertext::{
CudaIntegerRadixCiphertext, CudaRadixCiphertext, CudaUnsignedRadixCiphertext,
};
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaRadixCiphertext};
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::ComparisonType;
use crate::integer::server_key::comparator::Comparator;
use crate::shortint::ciphertext::Degree;
impl CudaServerKey {
/// Returns whether the clear scalar is outside of the
/// value range the ciphertext can hold.
///
/// - Returns None if the scalar is in the range of values that the ciphertext can represent
///
/// - Returns Some(ordering) when the scalar is out of representable range of the ciphertext.
/// - Equal will never be returned
/// - Less means the scalar is less than the min value representable by the ciphertext
/// - Greater means the scalar is greater than the max value representable by the ciphertext
pub(crate) fn is_scalar_out_of_bounds<T, Scalar>(
&self,
ct: &T,
scalar: Scalar,
) -> Option<std::cmp::Ordering>
where
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
{
let scalar_blocks =
BlockDecomposer::with_early_stop_at_zero(scalar, self.message_modulus.0.ilog2())
.iter_as::<u64>()
.collect::<Vec<_>>();
let ct_len = ct.as_ref().d_blocks.lwe_ciphertext_count();
if T::IS_SIGNED {
let sign_bit_pos = self.message_modulus.0.ilog2() - 1;
let sign_bit_is_set = scalar_blocks
.get(ct_len.0 - 1)
.map_or(false, |block| (block >> sign_bit_pos) == 1);
if scalar > Scalar::ZERO
&& (scalar_blocks.len() > ct_len.0
|| (scalar_blocks.len() == ct_len.0 && sign_bit_is_set))
{
// If the scalar is positive and any bit above the ct's n-1 bits is set,
// it means the scalar is bigger.
//
// This is checked in two steps:
// - If there are more scalar blocks than ct blocks, the scalar is trivially bigger
// - If there are the same number of blocks but the "sign bit" / msb of the scalar is
// set, then the scalar is trivially bigger
return Some(std::cmp::Ordering::Greater);
} else if scalar < Scalar::ZERO {
// If the scalar is negative and any bit above the ct's n-1 bits is not set,
// it means the scalar is smaller.
if ct_len.0 > scalar_blocks.len() {
// Ciphertext has more blocks, the scalar may be in range
return None;
}
// (returns false for empty iter)
let at_least_one_block_is_not_full_of_1s = scalar_blocks[ct_len.0..]
.iter()
.any(|&scalar_block| scalar_block != (self.message_modulus.0 as u64 - 1));
let sign_bit_pos = self.message_modulus.0.ilog2() - 1;
let sign_bit_is_unset = scalar_blocks
.get(ct_len.0 - 1)
.map_or(false, |block| (block >> sign_bit_pos) == 0);
if at_least_one_block_is_not_full_of_1s || sign_bit_is_unset {
// Scalar is smaller than lowest value of T
return Some(std::cmp::Ordering::Less);
}
}
} else {
// T is unsigned
if scalar < Scalar::ZERO {
// ct represents an unsigned value (always >= 0)
return Some(std::cmp::Ordering::Less);
} else if scalar > Scalar::ZERO {
// scalar is obviously bigger if it has non-zero
// blocks after lhs's last block
let is_scalar_obviously_bigger =
scalar_blocks.get(ct_len.0..).is_some_and(|sub_slice| {
sub_slice.iter().any(|&scalar_block| scalar_block != 0)
});
if is_scalar_obviously_bigger {
return Some(std::cmp::Ordering::Greater);
}
}
}
None
}
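For intuition, num_blocks radix blocks of log2(message_modulus) message bits each give a signed range of [-2^(nb-1), 2^(nb-1)-1] with nb = num_blocks * log2(message_modulus); the function above is effectively a block-wise version of that range check. A clear-value sketch of the same decision (kept in C++ like the earlier sketches rather than Rust, hypothetical helper, assumes nb < 64):

#include <cstdint>
#include <optional>

enum class Bound { Less, Greater };

// Returns Greater/Less when the clear scalar cannot be represented by a
// signed radix ciphertext with the given layout, nullopt otherwise.
static std::optional<Bound> is_signed_scalar_out_of_bounds(int64_t scalar,
                                                           uint32_t num_blocks,
                                                           uint32_t bits_per_block) {
  uint32_t nb = num_blocks * bits_per_block;
  int64_t min = -(int64_t(1) << (nb - 1));
  int64_t max = (int64_t(1) << (nb - 1)) - 1;
  if (scalar > max)
    return Bound::Greater; // the ciphertext is necessarily smaller
  if (scalar < min)
    return Bound::Less; // the ciphertext is necessarily bigger
  return std::nullopt; // in range, a real comparison is needed
}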
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_scalar_comparison_async<T>(
pub unsafe fn unchecked_signed_and_unsigned_scalar_comparison_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
op: ComparisonType,
signed_with_positive_scalar: bool,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
if scalar < T::ZERO {
if scalar < Scalar::ZERO {
// ct represents an unsigned (always >= 0)
let ct_res = self.create_trivial_radix(Comparator::IS_SUPERIOR, 1, stream);
return CudaBooleanBlock::from_cuda_radix_ciphertext(CudaRadixCiphertext::new(
ct_res.ciphertext.d_blocks,
ct_res.ciphertext.info,
));
let value = match op {
ComparisonType::GT | ComparisonType::GE => 1,
_ => 0,
};
let ct_res: T = self.create_trivial_radix(value, 1, stream);
return CudaBooleanBlock::from_cuda_radix_ciphertext(ct_res.into_inner());
}
let message_modulus = self.message_modulus.0;
@@ -50,11 +138,12 @@ impl CudaServerKey {
.get(ct.as_ref().d_blocks.lwe_ciphertext_count().0..)
.is_some_and(|sub_slice| sub_slice.iter().any(|&scalar_block| scalar_block != 0));
if is_scalar_obviously_bigger {
let ct_res = self.create_trivial_radix(Comparator::IS_INFERIOR, 1, stream);
return CudaBooleanBlock::from_cuda_radix_ciphertext(CudaRadixCiphertext::new(
ct_res.ciphertext.d_blocks,
ct_res.ciphertext.info,
));
let value = match op {
ComparisonType::LT | ComparisonType::LE => 1,
_ => 0,
};
let ct_res: T = self.create_trivial_radix(value, 1, stream);
return CudaBooleanBlock::from_cuda_radix_ciphertext(ct_res.into_inner());
}
// If we are still here, that means scalar_blocks above
@@ -105,7 +194,7 @@ impl CudaServerKey {
lwe_ciphertext_count.0 as u32,
scalar_blocks.len() as u32,
op,
false,
signed_with_positive_scalar,
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
@@ -133,7 +222,7 @@ impl CudaServerKey {
lwe_ciphertext_count.0 as u32,
scalar_blocks.len() as u32,
op,
false,
signed_with_positive_scalar,
);
}
}
@@ -145,49 +234,97 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_scalar_minmax_async<Scalar>(
pub unsafe fn unchecked_scalar_comparison_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
ct: &T,
scalar: Scalar,
op: ComparisonType,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext
) -> CudaBooleanBlock
where
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
if scalar < Scalar::ZERO {
// ct represents an unsigned (always >= 0)
return self.create_trivial_radix(Comparator::IS_SUPERIOR, 1, stream);
}
let num_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0;
if T::IS_SIGNED {
match self.is_scalar_out_of_bounds(ct, scalar) {
Some(std::cmp::Ordering::Greater) => {
// Scalar is greater than the bounds, so ciphertext is smaller
let result: T = match op {
ComparisonType::LT | ComparisonType::LE => {
self.create_trivial_radix(1, num_blocks, stream)
}
_ => self.create_trivial_radix(
0,
ct.as_ref().d_blocks.lwe_ciphertext_count().0,
stream,
),
};
return CudaBooleanBlock::from_cuda_radix_ciphertext(result.into_inner());
}
Some(std::cmp::Ordering::Less) => {
// Scalar is smaller than the bounds, so ciphertext is bigger
let result: T = match op {
ComparisonType::GT | ComparisonType::GE => {
self.create_trivial_radix(1, num_blocks, stream)
}
_ => self.create_trivial_radix(
0,
ct.as_ref().d_blocks.lwe_ciphertext_count().0,
stream,
),
};
return CudaBooleanBlock::from_cuda_radix_ciphertext(result.into_inner());
}
Some(std::cmp::Ordering::Equal) => unreachable!("Internal error: invalid value"),
None => {
// scalar is in range, fallthrough
}
}
if scalar >= Scalar::ZERO {
self.unchecked_signed_and_unsigned_scalar_comparison_async(
ct, scalar, op, true, stream,
)
} else {
let scalar_as_trivial = self.create_trivial_radix(scalar, num_blocks, stream);
self.unchecked_comparison_async(ct, &scalar_as_trivial, op, stream)
}
} else {
// Unsigned
self.unchecked_signed_and_unsigned_scalar_comparison_async(
ct, scalar, op, false, stream,
)
}
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_scalar_minmax_async<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
op: ComparisonType,
stream: &CudaStream,
) -> T
where
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
{
let message_modulus = self.message_modulus.0;
let mut scalar_blocks =
let scalar_blocks =
BlockDecomposer::with_early_stop_at_zero(scalar, message_modulus.ilog2())
.iter_as::<u64>()
.collect::<Vec<_>>();
// scalar is obviously bigger if it has non-zero
// blocks after lhs's last block
let is_scalar_obviously_bigger = scalar_blocks
.get(ct.as_ref().d_blocks.lwe_ciphertext_count().0..)
.is_some_and(|sub_slice| sub_slice.iter().any(|&scalar_block| scalar_block != 0));
if is_scalar_obviously_bigger {
return self.create_trivial_radix(Comparator::IS_INFERIOR, 1, stream);
}
// If we are still here, that means scalar_blocks above
// num_blocks are 0s, we can remove them
// as we will handle them separately.
scalar_blocks.truncate(ct.as_ref().d_blocks.lwe_ciphertext_count().0);
let d_scalar_blocks: CudaVec<u64> = CudaVec::from_cpu_async(&scalar_blocks, stream);
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let mut result = CudaUnsignedRadixCiphertext {
ciphertext: ct.as_ref().duplicate_async(stream),
};
let mut result = ct.duplicate_async(stream);
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
@@ -214,7 +351,7 @@ impl CudaServerKey {
lwe_ciphertext_count.0 as u32,
scalar_blocks.len() as u32,
op,
false,
T::IS_SIGNED,
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
@@ -242,7 +379,7 @@ impl CudaServerKey {
lwe_ciphertext_count.0 as u32,
scalar_blocks.len() as u32,
op,
false,
T::IS_SIGNED,
);
}
}
@@ -254,26 +391,28 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_scalar_eq_async<Scalar>(
pub unsafe fn unchecked_scalar_eq_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
{
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::EQ, stream)
}
pub fn unchecked_scalar_eq<T>(
pub fn unchecked_scalar_eq<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
{
let result = unsafe { self.unchecked_scalar_eq_async(ct, scalar, stream) };
stream.synchronize();
@@ -284,14 +423,15 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn scalar_eq_async<T>(
pub unsafe fn scalar_eq_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
{
let mut tmp_lhs;
let lhs = if ct.block_carries_are_empty() {
@@ -346,14 +486,15 @@ impl CudaServerKey {
/// let dec_result = cks.decrypt_bool(&ct_res);
/// assert_eq!(dec_result, msg1 == msg2);
/// ```
pub fn scalar_eq<T>(
pub fn scalar_eq<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
{
let result = unsafe { self.scalar_eq_async(ct, scalar, stream) };
stream.synchronize();
@@ -364,14 +505,15 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn scalar_ne_async<T>(
pub unsafe fn scalar_ne_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
{
let mut tmp_lhs;
let lhs = if ct.block_carries_are_empty() {
@@ -426,14 +568,15 @@ impl CudaServerKey {
/// let dec_result = cks.decrypt_bool(&ct_res);
/// assert_eq!(dec_result, msg1 != msg2);
/// ```
pub fn scalar_ne<T>(
pub fn scalar_ne<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.scalar_ne_async(ct, scalar, stream) };
stream.synchronize();
@@ -444,26 +587,28 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_scalar_ne_async<T>(
pub unsafe fn unchecked_scalar_ne_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
{
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::NE, stream)
}
pub fn unchecked_scalar_ne<T>(
pub fn unchecked_scalar_ne<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
{
let result = unsafe { self.unchecked_scalar_ne_async(ct, scalar, stream) };
stream.synchronize();
@@ -474,26 +619,28 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_scalar_gt_async<T>(
pub unsafe fn unchecked_scalar_gt_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::GT, stream)
}
pub fn unchecked_scalar_gt<T>(
pub fn unchecked_scalar_gt<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.unchecked_scalar_gt_async(ct, scalar, stream) };
stream.synchronize();
@@ -504,26 +651,28 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_scalar_ge_async<T>(
pub unsafe fn unchecked_scalar_ge_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::GE, stream)
}
pub fn unchecked_scalar_ge<T>(
pub fn unchecked_scalar_ge<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.unchecked_scalar_ge_async(ct, scalar, stream) };
stream.synchronize();
@@ -534,26 +683,28 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_scalar_lt_async<T>(
pub unsafe fn unchecked_scalar_lt_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::LT, stream)
}
pub fn unchecked_scalar_lt<T>(
pub fn unchecked_scalar_lt<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.unchecked_scalar_lt_async(ct, scalar, stream) };
stream.synchronize();
@@ -564,26 +715,28 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_scalar_le_async<T>(
pub unsafe fn unchecked_scalar_le_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::LE, stream)
}
pub fn unchecked_scalar_le<T>(
pub fn unchecked_scalar_le<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.unchecked_scalar_le_async(ct, scalar, stream) };
stream.synchronize();
@@ -593,14 +746,15 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn scalar_gt_async<T>(
pub unsafe fn scalar_gt_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let mut tmp_lhs;
let lhs = if ct.block_carries_are_empty() {
@@ -614,14 +768,15 @@ impl CudaServerKey {
self.unchecked_scalar_gt_async(lhs, scalar, stream)
}
pub fn scalar_gt<T>(
pub fn scalar_gt<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.scalar_gt_async(ct, scalar, stream) };
stream.synchronize();
@@ -632,14 +787,15 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn scalar_ge_async<T>(
pub unsafe fn scalar_ge_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let mut tmp_lhs;
let lhs = if ct.block_carries_are_empty() {
@@ -653,14 +809,15 @@ impl CudaServerKey {
self.unchecked_scalar_ge_async(lhs, scalar, stream)
}
pub fn scalar_ge<T>(
pub fn scalar_ge<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.scalar_ge_async(ct, scalar, stream) };
stream.synchronize();
@@ -671,14 +828,15 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    /// not be dropped until the stream is synchronized
pub unsafe fn scalar_lt_async<T>(
pub unsafe fn scalar_lt_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let mut tmp_lhs;
let lhs = if ct.block_carries_are_empty() {
@@ -692,14 +850,15 @@ impl CudaServerKey {
self.unchecked_scalar_lt_async(lhs, scalar, stream)
}
pub fn scalar_lt<T>(
pub fn scalar_lt<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.scalar_lt_async(ct, scalar, stream) };
stream.synchronize();
@@ -709,14 +868,15 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    /// not be dropped until the stream is synchronized
pub unsafe fn scalar_le_async<T>(
pub unsafe fn scalar_le_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let mut tmp_lhs;
let lhs = if ct.block_carries_are_empty() {
@@ -730,14 +890,15 @@ impl CudaServerKey {
self.unchecked_scalar_le_async(lhs, scalar, stream)
}
pub fn scalar_le<T>(
pub fn scalar_le<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaBooleanBlock
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.scalar_le_async(ct, scalar, stream) };
stream.synchronize();
@@ -748,26 +909,23 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    /// not be dropped until the stream is synchronized
pub unsafe fn unchecked_scalar_max_async<T>(
pub unsafe fn unchecked_scalar_max_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext
) -> T
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
self.unchecked_scalar_minmax_async(ct, scalar, ComparisonType::MAX, stream)
}
pub fn unchecked_scalar_max<T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext
pub fn unchecked_scalar_max<Scalar, T>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.unchecked_scalar_max_async(ct, scalar, stream) };
stream.synchronize();
@@ -778,26 +936,23 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    /// not be dropped until the stream is synchronized
pub unsafe fn unchecked_scalar_min_async<Scalar>(
pub unsafe fn unchecked_scalar_min_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext
) -> T
where
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
self.unchecked_scalar_minmax_async(ct, scalar, ComparisonType::MIN, stream)
}
pub fn unchecked_scalar_min<Scalar>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: Scalar,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext
pub fn unchecked_scalar_min<Scalar, T>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
where
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.unchecked_scalar_min_async(ct, scalar, stream) };
stream.synchronize();
@@ -808,14 +963,15 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn scalar_max_async<T>(
pub unsafe fn scalar_max_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext
) -> T
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let mut tmp_lhs;
let lhs = if ct.block_carries_are_empty() {
@@ -829,14 +985,10 @@ impl CudaServerKey {
self.unchecked_scalar_max_async(lhs, scalar, stream)
}
pub fn scalar_max<T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext
pub fn scalar_max<Scalar, T>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.scalar_max_async(ct, scalar, stream) };
stream.synchronize();
@@ -847,14 +999,15 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
    /// not be dropped until the stream is synchronized
pub unsafe fn scalar_min_async<T>(
pub unsafe fn scalar_min_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
ct: &T,
scalar: Scalar,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext
) -> T
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let mut tmp_lhs;
let lhs = if ct.block_carries_are_empty() {
@@ -868,14 +1021,10 @@ impl CudaServerKey {
self.unchecked_scalar_min_async(lhs, scalar, stream)
}
pub fn scalar_min<T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext
pub fn scalar_min<Scalar, T>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
where
T: DecomposableInto<u64>,
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.scalar_min_async(ct, scalar, stream) };
stream.synchronize();
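
For context, a minimal usage sketch of the generic entry points above; it is not part of this diff. It reuses the helpers that already appear in the commented-out edge-case test further down (gen_keys_gpu, CudaUnsignedRadixCiphertext::from_radix_ciphertext, to_boolean_block, decrypt_bool) and assumes the usual device-to-host conversions to_radix_ciphertext / decrypt_radix for reading back the min result; the safe wrappers synchronize the stream internally before returning:

let gpu_index = 0;
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);
let (cks, sks) = gen_keys_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, &stream);

let num_blocks = 4;
let ct = cks.encrypt_radix(97u64, num_blocks);
let d_ct = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &stream);

// Any Scalar: DecomposableInto<u64> can be compared against any T: CudaIntegerRadixCiphertext.
let d_is_ge = sks.scalar_ge(&d_ct, 42u64, &stream);
let is_ge = cks.decrypt_bool(&d_is_ge.to_boolean_block(&stream));
assert!(is_ge);

// scalar_min / scalar_max return a ciphertext of the same type T as their input.
let d_min = sks.scalar_min(&d_ct, 42u64, &stream);
let min: u64 = cks.decrypt_radix(&d_min.to_radix_ciphertext(&stream));
assert_eq!(min, 42);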

View File

@@ -7,6 +7,7 @@ pub(crate) mod test_neg;
pub(crate) mod test_rotate;
pub(crate) mod test_scalar_add;
pub(crate) mod test_scalar_bitwise_op;
pub(crate) mod test_scalar_comparison;
pub(crate) mod test_scalar_mul;
pub(crate) mod test_scalar_rotate;
pub(crate) mod test_scalar_shift;

View File

@@ -0,0 +1,92 @@
use crate::integer::gpu::server_key::radix::tests_unsigned::{
create_gpu_parametrized_test, GpuFunctionExecutor,
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_comparison::{
test_signed_default_scalar_function, test_signed_default_scalar_minmax,
test_signed_unchecked_scalar_function, test_signed_unchecked_scalar_minmax,
};
use crate::shortint::parameters::*;
/// This macro generates the tests for a given comparison fn
///
/// All our comparison functions have 2 variants:
/// - unchecked_$comparison_name
/// - $comparison_name
///
/// So, for example, for the `gt` comparison fn, this macro will generate the tests for
/// the 2 variants described above (a rough expansion for one invocation is sketched
/// in comments right after the macro definition).
macro_rules! define_gpu_signed_scalar_comparison_test_functions {
($comparison_name:ident, $clear_type:ty) => {
::paste::paste!{
fn [<integer_signed_unchecked_scalar_ $comparison_name _ $clear_type>]<P>(param: P) where P: Into<PBSParameters> {
let num_tests = 1;
let executor = GpuFunctionExecutor::new(&CudaServerKey::[<unchecked_scalar_ $comparison_name>]);
test_signed_unchecked_scalar_function(
param,
num_tests,
executor,
|lhs, rhs| $clear_type::from(<$clear_type>::$comparison_name(&lhs, &rhs)),
)
}
        fn [<integer_signed_default_scalar_ $comparison_name _ $clear_type>]<P>(param: P) where P: Into<PBSParameters> {
let num_tests = 10;
let executor = GpuFunctionExecutor::new(&CudaServerKey::[<scalar_ $comparison_name>]);
test_signed_default_scalar_function(
param,
num_tests,
executor,
|lhs, rhs| $clear_type::from(<$clear_type>::$comparison_name(&lhs, &rhs)),
)
}
create_gpu_parametrized_test!([<integer_signed_unchecked_scalar_ $comparison_name _ $clear_type>]);
        create_gpu_parametrized_test!([<integer_signed_default_scalar_ $comparison_name _ $clear_type>]);
}
};
}
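// Illustrative only (not part of this diff): invoking
// `define_gpu_signed_scalar_comparison_test_functions!(ge, i128)` below expands,
// roughly, to the following pair of test drivers plus their
// `create_gpu_parametrized_test!` registrations:
//
//     fn integer_signed_unchecked_scalar_ge_i128<P>(param: P)
//     where
//         P: Into<PBSParameters>,
//     {
//         let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_ge);
//         test_signed_unchecked_scalar_function(
//             param,
//             1, // num_tests
//             executor,
//             |lhs, rhs| i128::from(<i128>::ge(&lhs, &rhs)),
//         )
//     }
//
//     // ... and the analogous `integer_signed_default_scalar_ge_i128`, which uses
//     // `CudaServerKey::scalar_ge` and 10 test iterations.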
fn integer_signed_unchecked_scalar_min_i128<P>(params: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_min);
test_signed_unchecked_scalar_minmax(params, 2, executor, std::cmp::min::<i128>);
}
fn integer_signed_unchecked_scalar_max_i128<P>(params: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_max);
test_signed_unchecked_scalar_minmax(params, 2, executor, std::cmp::max::<i128>);
}
fn integer_signed_scalar_min_i128<P>(params: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_min);
test_signed_default_scalar_minmax(params, 2, executor, std::cmp::min::<i128>);
}
fn integer_signed_scalar_max_i128<P>(params: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_max);
test_signed_default_scalar_minmax(params, 2, executor, std::cmp::max::<i128>);
}
create_gpu_parametrized_test!(integer_signed_unchecked_scalar_max_i128);
create_gpu_parametrized_test!(integer_signed_unchecked_scalar_min_i128);
create_gpu_parametrized_test!(integer_signed_scalar_max_i128);
create_gpu_parametrized_test!(integer_signed_scalar_min_i128);
define_gpu_signed_scalar_comparison_test_functions!(eq, i128);
define_gpu_signed_scalar_comparison_test_functions!(ne, i128);
define_gpu_signed_scalar_comparison_test_functions!(lt, i128);
define_gpu_signed_scalar_comparison_test_functions!(le, i128);
define_gpu_signed_scalar_comparison_test_functions!(gt, i128);
define_gpu_signed_scalar_comparison_test_functions!(ge, i128);

View File

@@ -79,6 +79,137 @@ where
test_default_scalar_minmax(params, 2, executor, std::cmp::max::<U256>);
}
// The goal of this function is to ensure that scalar comparisons
// work when the scalar type used is either bigger or smaller (in bit size)
// compared to the ciphertext
//fn integer_unchecked_scalar_comparisons_edge(param: ClassicPBSParameters) {
// let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
//
// let gpu_index = 0;
// let device = CudaDevice::new(gpu_index);
// let stream = CudaStream::new_unchecked(device);
//
// let (cks, sks) = gen_keys_gpu(param, &stream);
//
// let mut rng = rand::thread_rng();
//
// for _ in 0..4 {
// let clear_a = rng.gen_range((u128::from(u64::MAX) + 1)..=u128::MAX);
// let smaller_clear = rng.gen::<u64>();
// let bigger_clear = rng.gen::<U256>();
//
// let a = cks.encrypt_radix(clear_a, num_block);
// // Copy to the GPU
// let d_a = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&a, &stream);
//
// // >=
// {
// let d_result = sks.unchecked_scalar_ge(&d_a, smaller_clear, &stream);
// let result = d_result.to_boolean_block(&stream);
// let decrypted = cks.decrypt_bool(&result);
// assert_eq!(decrypted, U256::from(clear_a) >= U256::from(smaller_clear));
//
// let d_result = sks.unchecked_scalar_ge(&d_a, bigger_clear, &stream);
// let result = d_result.to_boolean_block(&stream);
// let decrypted = cks.decrypt_bool(&result);
// assert_eq!(decrypted, U256::from(clear_a) >= bigger_clear);
// }
//
// // >
// {
// let d_result = sks.unchecked_scalar_gt(&d_a, smaller_clear, &stream);
// let result = d_result.to_boolean_block(&stream);
// let decrypted = cks.decrypt_bool(&result);
// assert_eq!(decrypted, U256::from(clear_a) > U256::from(smaller_clear));
//
// let d_result = sks.unchecked_scalar_gt(&d_a, bigger_clear, &stream);
// let result = d_result.to_boolean_block(&stream);
// let decrypted = cks.decrypt_bool(&result);
// assert_eq!(decrypted, U256::from(clear_a) > bigger_clear);
// }
//
// // <=
// {
// let d_result = sks.unchecked_scalar_le(&d_a, smaller_clear, &stream);
// let result = d_result.to_boolean_block(&stream);
// let decrypted = cks.decrypt_bool(&result);
// assert_eq!(decrypted, U256::from(clear_a) <= U256::from(smaller_clear));
//
// let d_result = sks.unchecked_scalar_le(&d_a, bigger_clear, &stream);
// let result = d_result.to_boolean_block(&stream);
// let decrypted = cks.decrypt_bool(&result);
// assert_eq!(decrypted, U256::from(clear_a) <= bigger_clear);
// }
//
// // <
// {
// let d_result = sks.unchecked_scalar_lt(&d_a, smaller_clear, &stream);
// let result = d_result.to_boolean_block(&stream);
// let decrypted = cks.decrypt_bool(&result);
// assert_eq!(decrypted, U256::from(clear_a) < U256::from(smaller_clear));
//
// let d_result = sks.unchecked_scalar_lt(&d_a, bigger_clear, &stream);
// let result = d_result.to_boolean_block(&stream);
// let decrypted = cks.decrypt_bool(&result);
// assert_eq!(decrypted, U256::from(clear_a) < bigger_clear);
// }
//
// // ==
// {
// let d_result = sks.unchecked_scalar_eq(&d_a, smaller_clear, &stream);
// let result = d_result.to_boolean_block(&stream);
// let decrypted = cks.decrypt_bool(&result);
// assert_eq!(decrypted, U256::from(clear_a) == U256::from(smaller_clear));
//
// let d_result = sks.unchecked_scalar_eq(&d_a, bigger_clear, &stream);
// let result = d_result.to_boolean_block(&stream);
// let decrypted = cks.decrypt_bool(&result);
// assert_eq!(decrypted, U256::from(clear_a) == bigger_clear);
// }
//
// // !=
// {
// let d_result = sks.unchecked_scalar_ne(&d_a, smaller_clear, &stream);
// let result = d_result.to_boolean_block(&stream);
// let decrypted = cks.decrypt_bool(&result);
// assert_eq!(decrypted, U256::from(clear_a) != U256::from(smaller_clear));
//
// let d_result = sks.unchecked_scalar_ne(&d_a, bigger_clear, &stream);
// let result = d_result.to_boolean_block(&stream);
// let decrypted = cks.decrypt_bool(&result);
// assert_eq!(decrypted, U256::from(clear_a) != bigger_clear);
// }
//
// // Here the goal is to test, the branching
// // made in the scalar sign function
// //
// // We are forcing one of the two branches to work on empty slices
// {
// let d_result = sks.unchecked_scalar_lt(&d_a, U256::ZERO, &stream);
// let result = d_result.to_boolean_block(&stream);
// let decrypted = cks.decrypt_bool(&result);
// assert_eq!(decrypted, U256::from(clear_a) < U256::ZERO);
//
// let d_result = sks.unchecked_scalar_lt(&d_a, U256::MAX, &stream);
// let result = d_result.to_boolean_block(&stream);
// let decrypted = cks.decrypt_bool(&result);
// assert_eq!(decrypted, U256::from(clear_a) < U256::MAX);
//
// // == (as it does not share same code)
// let d_result = sks.unchecked_scalar_eq(&d_a, U256::ZERO, &stream);
// let result = d_result.to_boolean_block(&stream);
// let decrypted = cks.decrypt_bool(&result);
// assert_eq!(decrypted, U256::from(clear_a) == U256::ZERO);
//
// // != (as it does not share same code)
// let d_result = sks.unchecked_scalar_ne(&d_a, U256::MAX, &stream);
// let result = d_result.to_boolean_block(&stream);
// let decrypted = cks.decrypt_bool(&result);
// assert_eq!(decrypted, U256::from(clear_a) != U256::MAX);
// }
// }
//}
create_gpu_parametrized_test!(integer_unchecked_scalar_min_u256);
create_gpu_parametrized_test!(integer_unchecked_scalar_max_u256);
create_gpu_parametrized_test!(integer_scalar_min_u256);
@@ -90,3 +221,8 @@ define_gpu_scalar_comparison_test_functions!(lt, U256);
define_gpu_scalar_comparison_test_functions!(le, U256);
define_gpu_scalar_comparison_test_functions!(gt, U256);
define_gpu_scalar_comparison_test_functions!(ge, U256);
//create_gpu_parametrized_test!(integer_unchecked_scalar_comparisons_edge {
// PARAM_MESSAGE_2_CARRY_2_KS_PBS,
//});
//