mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 06:13:58 -05:00
chore(shortint): make shortint div behavior match integer on div by zero
This commit is contained in:
@@ -21,8 +21,8 @@ typedef int (*UnaryAssignCallback)(const ShortintServerKey *, ShortintCiphertext
|
||||
|
||||
void test_shortint_unary_op(const ShortintClientKey *cks, const ShortintServerKey *sks,
|
||||
const ShortintClientKey *cks_small, const ShortintServerKey *sks_small,
|
||||
const uint32_t message_bits, const uint32_t carry_bits,
|
||||
uint64_t (*c_fun)(uint64_t), UnaryCallback api_fun) {
|
||||
const uint32_t message_bits, uint64_t (*c_fun)(uint64_t),
|
||||
UnaryCallback api_fun) {
|
||||
|
||||
int message_max = 1 << message_bits;
|
||||
|
||||
@@ -70,8 +70,7 @@ void test_shortint_unary_op(const ShortintClientKey *cks, const ShortintServerKe
|
||||
void test_shortint_unary_op_assign(const ShortintClientKey *cks, const ShortintServerKey *sks,
|
||||
const ShortintClientKey *cks_small,
|
||||
const ShortintServerKey *sks_small, const uint32_t message_bits,
|
||||
const uint32_t carry_bits, uint64_t (*c_fun)(uint64_t),
|
||||
UnaryAssignCallback api_fun) {
|
||||
uint64_t (*c_fun)(uint64_t), UnaryAssignCallback api_fun) {
|
||||
|
||||
int message_max = 1 << message_bits;
|
||||
|
||||
@@ -116,8 +115,8 @@ void test_shortint_unary_op_assign(const ShortintClientKey *cks, const ShortintS
|
||||
|
||||
void test_shortint_binary_op(const ShortintClientKey *cks, const ShortintServerKey *sks,
|
||||
const ShortintClientKey *cks_small, const ShortintServerKey *sks_small,
|
||||
const uint32_t message_bits, const uint32_t carry_bits,
|
||||
uint64_t (*c_fun)(uint64_t, uint64_t), BinaryCallback api_fun) {
|
||||
const uint32_t message_bits, uint64_t (*c_fun)(uint64_t, uint64_t),
|
||||
BinaryCallback api_fun) {
|
||||
|
||||
int message_max = 1 << message_bits;
|
||||
|
||||
@@ -176,7 +175,7 @@ void test_shortint_binary_op(const ShortintClientKey *cks, const ShortintServerK
|
||||
void test_shortint_binary_op_assign(const ShortintClientKey *cks, const ShortintServerKey *sks,
|
||||
const ShortintClientKey *cks_small,
|
||||
const ShortintServerKey *sks_small, const uint32_t message_bits,
|
||||
const uint32_t carry_bits,
|
||||
|
||||
uint64_t (*c_fun)(uint64_t, uint64_t),
|
||||
BinaryAssignCallback api_fun) {
|
||||
|
||||
@@ -233,9 +232,135 @@ void test_shortint_binary_op_assign(const ShortintClientKey *cks, const Shortint
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t homomorphic_div(uint64_t left, uint64_t right, uint64_t value_on_div_by_zero) {
|
||||
if (right != 0) {
|
||||
return left / right;
|
||||
} else {
|
||||
// Special value chosen in the shortint implementation in case of a division by 0
|
||||
return value_on_div_by_zero;
|
||||
}
|
||||
}
|
||||
|
||||
void test_shortint_div(const ShortintClientKey *cks, const ShortintServerKey *sks,
|
||||
const ShortintClientKey *cks_small, const ShortintServerKey *sks_small,
|
||||
const uint32_t message_bits) {
|
||||
|
||||
int message_max = 1 << message_bits;
|
||||
|
||||
for (int is_big = 0; is_big < 2; ++is_big) {
|
||||
for (int val_left = 0; val_left < message_max; ++val_left) {
|
||||
for (int val_right = 0; val_right < message_max; ++val_right) {
|
||||
ShortintCiphertext *ct_left = NULL;
|
||||
ShortintCiphertext *ct_right = NULL;
|
||||
ShortintCiphertext *ct_result = NULL;
|
||||
const ShortintClientKey *cks_in_use = NULL;
|
||||
const ShortintServerKey *sks_in_use = NULL;
|
||||
|
||||
uint64_t left = (uint64_t)val_left;
|
||||
uint64_t right = (uint64_t)val_right;
|
||||
|
||||
uint64_t expected = homomorphic_div(left, right, (uint64_t)(message_max - 1)) % message_max;
|
||||
|
||||
if (is_big == 1) {
|
||||
cks_in_use = cks;
|
||||
sks_in_use = sks;
|
||||
|
||||
int encrypt_left_ok = shortint_client_key_encrypt(cks_in_use, left, &ct_left);
|
||||
assert(encrypt_left_ok == 0);
|
||||
|
||||
int encrypt_right_ok = shortint_client_key_encrypt(cks_in_use, right, &ct_right);
|
||||
assert(encrypt_right_ok == 0);
|
||||
} else {
|
||||
cks_in_use = cks_small;
|
||||
sks_in_use = sks_small;
|
||||
|
||||
int encrypt_left_ok = shortint_client_key_encrypt(cks_in_use, left, &ct_left);
|
||||
assert(encrypt_left_ok == 0);
|
||||
|
||||
int encrypt_right_ok = shortint_client_key_encrypt(cks_in_use, right, &ct_right);
|
||||
assert(encrypt_right_ok == 0);
|
||||
}
|
||||
|
||||
int api_call_ok =
|
||||
shortint_server_key_unchecked_div(sks_in_use, ct_left, ct_right, &ct_result);
|
||||
assert(api_call_ok == 0);
|
||||
|
||||
uint64_t decrypted_result = -1;
|
||||
|
||||
int decrypt_ok = shortint_client_key_decrypt(cks_in_use, ct_result, &decrypted_result);
|
||||
assert(decrypt_ok == 0);
|
||||
|
||||
assert(decrypted_result == expected);
|
||||
|
||||
shortint_destroy_ciphertext(ct_left);
|
||||
shortint_destroy_ciphertext(ct_right);
|
||||
shortint_destroy_ciphertext(ct_result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void test_shortint_div_assign(const ShortintClientKey *cks, const ShortintServerKey *sks,
|
||||
const ShortintClientKey *cks_small,
|
||||
const ShortintServerKey *sks_small, const uint32_t message_bits) {
|
||||
|
||||
int message_max = 1 << message_bits;
|
||||
|
||||
for (int is_big = 0; is_big < 2; ++is_big) {
|
||||
for (int val_left = 0; val_left < message_max; ++val_left) {
|
||||
for (int val_right = 0; val_right < message_max; ++val_right) {
|
||||
ShortintCiphertext *ct_left_and_result = NULL;
|
||||
ShortintCiphertext *ct_right = NULL;
|
||||
const ShortintClientKey *cks_in_use = NULL;
|
||||
const ShortintServerKey *sks_in_use = NULL;
|
||||
|
||||
uint64_t left = (uint64_t)val_left;
|
||||
uint64_t right = (uint64_t)val_right;
|
||||
|
||||
uint64_t expected = homomorphic_div(left, right, (uint64_t)(message_max - 1)) % message_max;
|
||||
|
||||
if (is_big == 1) {
|
||||
cks_in_use = cks;
|
||||
sks_in_use = sks;
|
||||
|
||||
int encrypt_left_ok = shortint_client_key_encrypt(cks_in_use, left, &ct_left_and_result);
|
||||
assert(encrypt_left_ok == 0);
|
||||
|
||||
int encrypt_right_ok = shortint_client_key_encrypt(cks_in_use, right, &ct_right);
|
||||
assert(encrypt_right_ok == 0);
|
||||
} else {
|
||||
cks_in_use = cks_small;
|
||||
sks_in_use = sks_small;
|
||||
|
||||
int encrypt_left_ok = shortint_client_key_encrypt(cks_in_use, left, &ct_left_and_result);
|
||||
assert(encrypt_left_ok == 0);
|
||||
|
||||
int encrypt_right_ok = shortint_client_key_encrypt(cks_in_use, right, &ct_right);
|
||||
assert(encrypt_right_ok == 0);
|
||||
}
|
||||
|
||||
int api_call_ok =
|
||||
shortint_server_key_unchecked_div_assign(sks_in_use, ct_left_and_result, ct_right);
|
||||
assert(api_call_ok == 0);
|
||||
|
||||
uint64_t decrypted_result = -1;
|
||||
|
||||
int decrypt_ok =
|
||||
shortint_client_key_decrypt(cks_in_use, ct_left_and_result, &decrypted_result);
|
||||
assert(decrypt_ok == 0);
|
||||
|
||||
assert(decrypted_result == expected);
|
||||
|
||||
shortint_destroy_ciphertext(ct_left_and_result);
|
||||
shortint_destroy_ciphertext(ct_right);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void test_shortint_binary_scalar_op(
|
||||
const ShortintClientKey *cks, const ShortintServerKey *sks, const ShortintClientKey *cks_small,
|
||||
const ShortintServerKey *sks_small, const uint32_t message_bits, const uint32_t carry_bits,
|
||||
const ShortintServerKey *sks_small, const uint32_t message_bits,
|
||||
uint64_t (*c_fun)(uint64_t, uint8_t),
|
||||
int (*api_fun)(const ShortintServerKey *, ShortintCiphertext *, uint8_t, ShortintCiphertext **),
|
||||
uint8_t forbidden_scalar_values[], size_t forbidden_scalar_values_len) {
|
||||
@@ -302,7 +427,7 @@ void test_shortint_binary_scalar_op(
|
||||
|
||||
void test_shortint_binary_scalar_op_assign(
|
||||
const ShortintClientKey *cks, const ShortintServerKey *sks, const ShortintClientKey *cks_small,
|
||||
const ShortintServerKey *sks_small, const uint32_t message_bits, const uint32_t carry_bits,
|
||||
const ShortintServerKey *sks_small, const uint32_t message_bits,
|
||||
uint64_t (*c_fun)(uint64_t, uint8_t),
|
||||
int (*api_fun)(const ShortintServerKey *, ShortintCiphertext *, uint8_t),
|
||||
uint8_t forbidden_scalar_values[], size_t forbidden_scalar_values_len) {
|
||||
@@ -371,15 +496,6 @@ uint64_t sub(uint64_t left, uint64_t right) { return left - right; }
|
||||
uint64_t mul(uint64_t left, uint64_t right) { return left * right; }
|
||||
uint64_t neg(uint64_t in) { return -in; }
|
||||
|
||||
uint64_t homomorphic_div(uint64_t left, uint64_t right) {
|
||||
if (right != 0) {
|
||||
return left / right;
|
||||
} else {
|
||||
// Special value chosen in the shortint implementation in case of a division by 0
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t bitand(uint64_t left, uint64_t right) { return left & right; }
|
||||
uint64_t bitxor(uint64_t left, uint64_t right) { return left ^ right; }
|
||||
uint64_t bitor (uint64_t left, uint64_t right) { return left | right; }
|
||||
@@ -473,258 +589,237 @@ void test_server_key(void) {
|
||||
assert(deser_sks_ok == 0);
|
||||
|
||||
printf("add\n");
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, add,
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, add,
|
||||
(BinaryCallback)shortint_server_key_smart_add);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, add,
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, add,
|
||||
(BinaryCallback)shortint_server_key_unchecked_add);
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, add,
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, add,
|
||||
(BinaryAssignCallback)shortint_server_key_smart_add_assign);
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, add,
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, add,
|
||||
(BinaryAssignCallback)shortint_server_key_unchecked_add_assign);
|
||||
|
||||
printf("sub\n");
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, sub,
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, sub,
|
||||
(BinaryCallback)shortint_server_key_smart_sub);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, sub,
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, sub,
|
||||
(BinaryCallback)shortint_server_key_unchecked_sub);
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, sub,
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, sub,
|
||||
(BinaryAssignCallback)shortint_server_key_smart_sub_assign);
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, sub,
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, sub,
|
||||
(BinaryAssignCallback)shortint_server_key_unchecked_sub_assign);
|
||||
|
||||
printf("mul\n");
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, mul,
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, mul,
|
||||
(BinaryCallback)shortint_server_key_smart_mul);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, mul,
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, mul,
|
||||
(BinaryCallback)shortint_server_key_unchecked_mul);
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, mul,
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, mul,
|
||||
(BinaryAssignCallback)shortint_server_key_smart_mul_assign);
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, mul,
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, mul,
|
||||
(BinaryAssignCallback)shortint_server_key_unchecked_mul_assign);
|
||||
|
||||
printf("left_shift\n");
|
||||
test_shortint_binary_scalar_op(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, left_shift,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, left_shift,
|
||||
(BinaryScalarCallback)shortint_server_key_smart_scalar_left_shift, NULL, 0);
|
||||
test_shortint_binary_scalar_op(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, left_shift,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, left_shift,
|
||||
(BinaryScalarCallback)shortint_server_key_unchecked_scalar_left_shift, NULL, 0);
|
||||
test_shortint_binary_scalar_op_assign(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, left_shift,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, left_shift,
|
||||
shortint_server_key_smart_scalar_left_shift_assign, NULL, 0);
|
||||
test_shortint_binary_scalar_op_assign(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, left_shift,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, left_shift,
|
||||
shortint_server_key_unchecked_scalar_left_shift_assign, NULL, 0);
|
||||
|
||||
printf("right_shift\n");
|
||||
test_shortint_binary_scalar_op(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, right_shift,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, right_shift,
|
||||
(BinaryScalarCallback)shortint_server_key_smart_scalar_right_shift, NULL, 0);
|
||||
test_shortint_binary_scalar_op(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, right_shift,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, right_shift,
|
||||
(BinaryScalarCallback)shortint_server_key_unchecked_scalar_right_shift, NULL, 0);
|
||||
test_shortint_binary_scalar_op_assign(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, right_shift,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, right_shift,
|
||||
shortint_server_key_smart_scalar_right_shift_assign, NULL, 0);
|
||||
test_shortint_binary_scalar_op_assign(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, right_shift,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, right_shift,
|
||||
shortint_server_key_unchecked_scalar_right_shift_assign, NULL, 0);
|
||||
|
||||
printf("scalar_add\n");
|
||||
test_shortint_binary_scalar_op(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_add,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_add,
|
||||
(BinaryScalarCallback)shortint_server_key_smart_scalar_add, NULL, 0);
|
||||
test_shortint_binary_scalar_op(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_add,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_add,
|
||||
(BinaryScalarCallback)shortint_server_key_unchecked_scalar_add, NULL, 0);
|
||||
test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, scalar_add,
|
||||
shortint_server_key_smart_scalar_add_assign, NULL, 0);
|
||||
scalar_add, shortint_server_key_smart_scalar_add_assign,
|
||||
NULL, 0);
|
||||
test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, scalar_add,
|
||||
shortint_server_key_unchecked_scalar_add_assign, NULL, 0);
|
||||
scalar_add, shortint_server_key_unchecked_scalar_add_assign,
|
||||
NULL, 0);
|
||||
|
||||
printf("scalar_sub\n");
|
||||
test_shortint_binary_scalar_op(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_sub,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_sub,
|
||||
(BinaryScalarCallback)shortint_server_key_smart_scalar_sub, NULL, 0);
|
||||
test_shortint_binary_scalar_op(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_sub,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_sub,
|
||||
(BinaryScalarCallback)shortint_server_key_unchecked_scalar_sub, NULL, 0);
|
||||
test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, scalar_sub,
|
||||
shortint_server_key_smart_scalar_sub_assign, NULL, 0);
|
||||
scalar_sub, shortint_server_key_smart_scalar_sub_assign,
|
||||
NULL, 0);
|
||||
test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, scalar_sub,
|
||||
shortint_server_key_unchecked_scalar_sub_assign, NULL, 0);
|
||||
scalar_sub, shortint_server_key_unchecked_scalar_sub_assign,
|
||||
NULL, 0);
|
||||
|
||||
printf("scalar_mul\n");
|
||||
test_shortint_binary_scalar_op(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_mul,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_mul,
|
||||
(BinaryScalarCallback)shortint_server_key_smart_scalar_mul, NULL, 0);
|
||||
test_shortint_binary_scalar_op(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_mul,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_mul,
|
||||
(BinaryScalarCallback)shortint_server_key_unchecked_scalar_mul, NULL, 0);
|
||||
test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, scalar_mul,
|
||||
shortint_server_key_smart_scalar_mul_assign, NULL, 0);
|
||||
scalar_mul, shortint_server_key_smart_scalar_mul_assign,
|
||||
NULL, 0);
|
||||
test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, scalar_mul,
|
||||
shortint_server_key_unchecked_scalar_mul_assign, NULL, 0);
|
||||
scalar_mul, shortint_server_key_unchecked_scalar_mul_assign,
|
||||
NULL, 0);
|
||||
|
||||
printf("bitand\n");
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, bitand, (BinaryCallback)shortint_server_key_smart_bitand);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, bitand, (BinaryCallback)shortint_server_key_unchecked_bitand);
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, bitand,
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitand,
|
||||
(BinaryCallback)shortint_server_key_smart_bitand);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitand,
|
||||
(BinaryCallback)shortint_server_key_unchecked_bitand);
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitand,
|
||||
(BinaryAssignCallback)shortint_server_key_smart_bitand_assign);
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, bitand,
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitand,
|
||||
(BinaryAssignCallback)shortint_server_key_unchecked_bitand_assign);
|
||||
|
||||
printf("bitxor\n");
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits,
|
||||
bitxor, (BinaryCallback)shortint_server_key_smart_bitxor);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits,
|
||||
bitxor, (BinaryCallback)shortint_server_key_unchecked_bitxor);
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, bitxor,
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitxor,
|
||||
(BinaryCallback)shortint_server_key_smart_bitxor);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitxor,
|
||||
(BinaryCallback)shortint_server_key_unchecked_bitxor);
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitxor,
|
||||
(BinaryAssignCallback)shortint_server_key_smart_bitxor_assign);
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, bitxor,
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitxor,
|
||||
(BinaryAssignCallback)shortint_server_key_unchecked_bitxor_assign);
|
||||
|
||||
printf("bitor\n");
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits,
|
||||
bitor, (BinaryCallback)shortint_server_key_smart_bitor);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits,
|
||||
bitor, (BinaryCallback)shortint_server_key_unchecked_bitor);
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, bitor,
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitor,
|
||||
(BinaryCallback)shortint_server_key_smart_bitor);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitor,
|
||||
(BinaryCallback)shortint_server_key_unchecked_bitor);
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitor,
|
||||
(BinaryAssignCallback)shortint_server_key_smart_bitor_assign);
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, bitor,
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitor,
|
||||
(BinaryAssignCallback)shortint_server_key_unchecked_bitor_assign);
|
||||
|
||||
printf("greater\n");
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits,
|
||||
greater, (BinaryCallback)shortint_server_key_smart_greater);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits,
|
||||
greater, (BinaryCallback)shortint_server_key_unchecked_greater);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, greater,
|
||||
(BinaryCallback)shortint_server_key_smart_greater);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, greater,
|
||||
(BinaryCallback)shortint_server_key_unchecked_greater);
|
||||
|
||||
printf("greater_or_equal\n");
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits,
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
greater_or_equal,
|
||||
(BinaryCallback)shortint_server_key_smart_greater_or_equal);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits,
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
greater_or_equal,
|
||||
(BinaryCallback)shortint_server_key_unchecked_greater_or_equal);
|
||||
|
||||
printf("less\n");
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits,
|
||||
less, (BinaryCallback)shortint_server_key_smart_less);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits,
|
||||
less, (BinaryCallback)shortint_server_key_unchecked_less);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, less,
|
||||
(BinaryCallback)shortint_server_key_smart_less);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, less,
|
||||
(BinaryCallback)shortint_server_key_unchecked_less);
|
||||
|
||||
printf("less_or_equal\n");
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits,
|
||||
less_or_equal, (BinaryCallback)shortint_server_key_smart_less_or_equal);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits,
|
||||
less_or_equal,
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, less_or_equal,
|
||||
(BinaryCallback)shortint_server_key_smart_less_or_equal);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, less_or_equal,
|
||||
(BinaryCallback)shortint_server_key_unchecked_less_or_equal);
|
||||
|
||||
printf("equal\n");
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits,
|
||||
equal, (BinaryCallback)shortint_server_key_smart_equal);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits,
|
||||
equal, (BinaryCallback)shortint_server_key_unchecked_equal);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, equal,
|
||||
(BinaryCallback)shortint_server_key_smart_equal);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, equal,
|
||||
(BinaryCallback)shortint_server_key_unchecked_equal);
|
||||
|
||||
printf("not_equal\n");
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits,
|
||||
not_equal, (BinaryCallback)shortint_server_key_smart_not_equal);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits,
|
||||
not_equal, (BinaryCallback)shortint_server_key_unchecked_not_equal);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, not_equal,
|
||||
(BinaryCallback)shortint_server_key_smart_not_equal);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, not_equal,
|
||||
(BinaryCallback)shortint_server_key_unchecked_not_equal);
|
||||
|
||||
printf("scalar_greater\n");
|
||||
test_shortint_binary_scalar_op(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_greater,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_greater,
|
||||
(BinaryScalarCallback)shortint_server_key_smart_scalar_greater, NULL, 0);
|
||||
|
||||
printf("scalar_greater_or_equal\n");
|
||||
test_shortint_binary_scalar_op(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_greater_or_equal,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_greater_or_equal,
|
||||
(BinaryScalarCallback)shortint_server_key_smart_scalar_greater_or_equal, NULL, 0);
|
||||
|
||||
printf("scalar_less\n");
|
||||
test_shortint_binary_scalar_op(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_less,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_less,
|
||||
(BinaryScalarCallback)shortint_server_key_smart_scalar_less, NULL, 0);
|
||||
|
||||
printf("scalar_less_or_equal\n");
|
||||
test_shortint_binary_scalar_op(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_less_or_equal,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_less_or_equal,
|
||||
(BinaryScalarCallback)shortint_server_key_smart_scalar_less_or_equal, NULL, 0);
|
||||
|
||||
printf("scalar_equal\n");
|
||||
test_shortint_binary_scalar_op(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_equal,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_equal,
|
||||
(BinaryScalarCallback)shortint_server_key_smart_scalar_equal, NULL, 0);
|
||||
|
||||
printf("scalar_not_equal\n");
|
||||
test_shortint_binary_scalar_op(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_not_equal,
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_not_equal,
|
||||
(BinaryScalarCallback)shortint_server_key_smart_scalar_not_equal, NULL, 0);
|
||||
|
||||
printf("neg\n");
|
||||
test_shortint_unary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, neg,
|
||||
test_shortint_unary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, neg,
|
||||
(UnaryCallback)shortint_server_key_smart_neg);
|
||||
test_shortint_unary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, neg,
|
||||
test_shortint_unary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, neg,
|
||||
(UnaryCallback)shortint_server_key_unchecked_neg);
|
||||
test_shortint_unary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, neg,
|
||||
test_shortint_unary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, neg,
|
||||
(UnaryAssignCallback)shortint_server_key_smart_neg_assign);
|
||||
test_shortint_unary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, neg,
|
||||
test_shortint_unary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, neg,
|
||||
(UnaryAssignCallback)shortint_server_key_unchecked_neg_assign);
|
||||
|
||||
printf("div\n");
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits,
|
||||
homomorphic_div, (BinaryCallback)shortint_server_key_smart_div);
|
||||
test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits,
|
||||
homomorphic_div, (BinaryCallback)shortint_server_key_unchecked_div);
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, homomorphic_div,
|
||||
(BinaryAssignCallback)shortint_server_key_smart_div_assign);
|
||||
test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, homomorphic_div,
|
||||
(BinaryAssignCallback)shortint_server_key_unchecked_div_assign);
|
||||
test_shortint_div(deser_cks, deser_sks, cks_small, sks_small, message_bits);
|
||||
test_shortint_div(deser_cks, deser_sks, cks_small, sks_small, message_bits);
|
||||
test_shortint_div_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits);
|
||||
test_shortint_div_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits);
|
||||
|
||||
printf("scalar_div\n");
|
||||
uint8_t forbidden_scalar_div_values[1] = {0};
|
||||
test_shortint_binary_scalar_op(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, scalar_div,
|
||||
scalar_div,
|
||||
(BinaryScalarCallback)shortint_server_key_unchecked_scalar_div,
|
||||
forbidden_scalar_div_values, 1);
|
||||
test_shortint_binary_scalar_op_assign(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_div,
|
||||
shortint_server_key_unchecked_scalar_div_assign, forbidden_scalar_div_values, 1);
|
||||
test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
scalar_div, shortint_server_key_unchecked_scalar_div_assign,
|
||||
forbidden_scalar_div_values, 1);
|
||||
printf("scalar_mod\n");
|
||||
uint8_t forbidden_scalar_mod_values[1] = {0};
|
||||
test_shortint_binary_scalar_op(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
carry_bits, scalar_mod,
|
||||
scalar_mod,
|
||||
(BinaryScalarCallback)shortint_server_key_unchecked_scalar_mod,
|
||||
forbidden_scalar_mod_values, 1);
|
||||
test_shortint_binary_scalar_op_assign(
|
||||
deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_mod,
|
||||
shortint_server_key_unchecked_scalar_mod_assign, forbidden_scalar_mod_values, 1);
|
||||
test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits,
|
||||
scalar_mod, shortint_server_key_unchecked_scalar_mod_assign,
|
||||
forbidden_scalar_mod_values, 1);
|
||||
|
||||
shortint_destroy_client_key(cks);
|
||||
shortint_destroy_client_key(cks_small);
|
||||
|
||||
@@ -2,10 +2,10 @@ use crate::shortint::ciphertext::Degree;
|
||||
use crate::shortint::engine::{EngineResult, ShortintEngine};
|
||||
use crate::shortint::{Ciphertext, ServerKey};
|
||||
|
||||
// Specific division function returning 0 in case of a division by 0
|
||||
pub(crate) fn safe_division(x: u64, y: u64) -> u64 {
|
||||
// Specific division function returning value_on_div_by_zero in case of a division by 0
|
||||
pub(crate) fn safe_division(x: u64, y: u64, value_on_div_by_zero: u64) -> u64 {
|
||||
if y == 0 {
|
||||
0
|
||||
value_on_div_by_zero
|
||||
} else {
|
||||
x / y
|
||||
}
|
||||
@@ -29,11 +29,12 @@ impl ShortintEngine {
|
||||
ct_left: &mut Ciphertext,
|
||||
ct_right: &Ciphertext,
|
||||
) -> EngineResult<()> {
|
||||
let value_on_div_by_zero = (ct_left.message_modulus.0 - 1) as u64;
|
||||
self.unchecked_evaluate_bivariate_function_assign(
|
||||
server_key,
|
||||
ct_left,
|
||||
ct_right,
|
||||
safe_division,
|
||||
|x, y| safe_division(x, y, value_on_div_by_zero),
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -93,7 +94,7 @@ impl ShortintEngine {
|
||||
ct: &mut Ciphertext,
|
||||
scalar: u8,
|
||||
) -> EngineResult<()> {
|
||||
assert_ne!(scalar, 0);
|
||||
assert_ne!(scalar, 0, "attempt to divide by zero");
|
||||
let lookup_table = self.generate_lookup_table(server_key, |x| x / (scalar as u64))?;
|
||||
self.apply_lookup_table_assign(server_key, ct, &lookup_table)?;
|
||||
ct.degree = Degree(ct.degree.0 / scalar as usize);
|
||||
|
||||
@@ -9,7 +9,8 @@ impl ServerKey {
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// /!\ A division by zero returns 0!
|
||||
/// /!\ A division by zero returns the input ciphertext maximum message value! For 2 bits of
|
||||
/// message it will therefore return 3.
|
||||
///
|
||||
/// This function, like all "default" operations (i.e. not smart, checked or unchecked), will
|
||||
/// check that the input ciphertext carries are empty and clears them if it's not the case and
|
||||
@@ -69,7 +70,8 @@ impl ServerKey {
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// /!\ A division by zero returns 0!
|
||||
/// /!\ A division by zero returns the input ciphertext maximum message value! For 2 bits of
|
||||
/// message it will therefore return 3.
|
||||
///
|
||||
/// This function, like all "default" operations (i.e. not smart, checked or unchecked), will
|
||||
/// check that the input ciphertext carries are empty and clears them if it's not the case and
|
||||
@@ -140,7 +142,8 @@ impl ServerKey {
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// /!\ A division by zero returns 0!
|
||||
/// /!\ A division by zero returns the input ciphertext maximum message value! For 2 bits of
|
||||
/// message it will therefore return 3.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
@@ -192,7 +195,8 @@ impl ServerKey {
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// /!\ A division by zero returns 0!
|
||||
/// /!\ A division by zero returns the input ciphertext maximum message value! For 2 bits of
|
||||
/// message it will therefore return 3.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
@@ -246,7 +250,8 @@ impl ServerKey {
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// /!\ A division by zero returns 0!
|
||||
/// /!\ A division by zero returns the input ciphertext maximum message value! For 2 bits of
|
||||
/// message it will therefore return 3.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
@@ -298,7 +303,8 @@ impl ServerKey {
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// /!\ A division by zero returns 0!
|
||||
/// /!\ A division by zero returns the input ciphertext maximum message value! For 2 bits of
|
||||
/// message it will therefore return 3.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
|
||||
@@ -2072,6 +2072,19 @@ where
|
||||
|
||||
let modulus = cks.parameters.message_modulus().0 as u64;
|
||||
|
||||
// check div by 0 result
|
||||
{
|
||||
let numerator = 1u64;
|
||||
let denominator = 0u64;
|
||||
|
||||
let ct_num = cks.encrypt(numerator);
|
||||
let ct_denom = cks.encrypt(denominator);
|
||||
let ct_res = sks.unchecked_div(&ct_num, &ct_denom);
|
||||
|
||||
let res = cks.decrypt(&ct_res);
|
||||
assert_eq!(res, (ct_num.message_modulus.0 - 1) as u64)
|
||||
}
|
||||
|
||||
for _ in 0..NB_TEST {
|
||||
let clear_0 = rng.gen::<u64>() % modulus;
|
||||
let clear_1 = (rng.gen::<u64>() % (modulus - 1)) + 1;
|
||||
|
||||
Reference in New Issue
Block a user