From f982c585386d0a2cd9c628c2987647686cede649 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Wed, 12 Jul 2023 14:31:46 +0200 Subject: [PATCH] chore(shortint): make shortint div behavior match integer on div by zero --- tfhe/c_api_tests/test_shortint_server_key.c | 371 +++++++++++------- .../shortint/engine/server_side/div_mod.rs | 11 +- tfhe/src/shortint/server_key/div_mod.rs | 18 +- .../src/shortint/server_key/tests/shortint.rs | 13 + 4 files changed, 264 insertions(+), 149 deletions(-) diff --git a/tfhe/c_api_tests/test_shortint_server_key.c b/tfhe/c_api_tests/test_shortint_server_key.c index 18fb688fb..407aa17ba 100644 --- a/tfhe/c_api_tests/test_shortint_server_key.c +++ b/tfhe/c_api_tests/test_shortint_server_key.c @@ -21,8 +21,8 @@ typedef int (*UnaryAssignCallback)(const ShortintServerKey *, ShortintCiphertext void test_shortint_unary_op(const ShortintClientKey *cks, const ShortintServerKey *sks, const ShortintClientKey *cks_small, const ShortintServerKey *sks_small, - const uint32_t message_bits, const uint32_t carry_bits, - uint64_t (*c_fun)(uint64_t), UnaryCallback api_fun) { + const uint32_t message_bits, uint64_t (*c_fun)(uint64_t), + UnaryCallback api_fun) { int message_max = 1 << message_bits; @@ -70,8 +70,7 @@ void test_shortint_unary_op(const ShortintClientKey *cks, const ShortintServerKe void test_shortint_unary_op_assign(const ShortintClientKey *cks, const ShortintServerKey *sks, const ShortintClientKey *cks_small, const ShortintServerKey *sks_small, const uint32_t message_bits, - const uint32_t carry_bits, uint64_t (*c_fun)(uint64_t), - UnaryAssignCallback api_fun) { + uint64_t (*c_fun)(uint64_t), UnaryAssignCallback api_fun) { int message_max = 1 << message_bits; @@ -116,8 +115,8 @@ void test_shortint_unary_op_assign(const ShortintClientKey *cks, const ShortintS void test_shortint_binary_op(const ShortintClientKey *cks, const ShortintServerKey *sks, const ShortintClientKey *cks_small, const ShortintServerKey *sks_small, - const uint32_t message_bits, const uint32_t carry_bits, - uint64_t (*c_fun)(uint64_t, uint64_t), BinaryCallback api_fun) { + const uint32_t message_bits, uint64_t (*c_fun)(uint64_t, uint64_t), + BinaryCallback api_fun) { int message_max = 1 << message_bits; @@ -176,7 +175,7 @@ void test_shortint_binary_op(const ShortintClientKey *cks, const ShortintServerK void test_shortint_binary_op_assign(const ShortintClientKey *cks, const ShortintServerKey *sks, const ShortintClientKey *cks_small, const ShortintServerKey *sks_small, const uint32_t message_bits, - const uint32_t carry_bits, + uint64_t (*c_fun)(uint64_t, uint64_t), BinaryAssignCallback api_fun) { @@ -233,9 +232,135 @@ void test_shortint_binary_op_assign(const ShortintClientKey *cks, const Shortint } } +uint64_t homomorphic_div(uint64_t left, uint64_t right, uint64_t value_on_div_by_zero) { + if (right != 0) { + return left / right; + } else { + // Special value chosen in the shortint implementation in case of a division by 0 + return value_on_div_by_zero; + } +} + +void test_shortint_div(const ShortintClientKey *cks, const ShortintServerKey *sks, + const ShortintClientKey *cks_small, const ShortintServerKey *sks_small, + const uint32_t message_bits) { + + int message_max = 1 << message_bits; + + for (int is_big = 0; is_big < 2; ++is_big) { + for (int val_left = 0; val_left < message_max; ++val_left) { + for (int val_right = 0; val_right < message_max; ++val_right) { + ShortintCiphertext *ct_left = NULL; + ShortintCiphertext *ct_right = NULL; + ShortintCiphertext *ct_result = NULL; + const ShortintClientKey *cks_in_use = NULL; + const ShortintServerKey *sks_in_use = NULL; + + uint64_t left = (uint64_t)val_left; + uint64_t right = (uint64_t)val_right; + + uint64_t expected = homomorphic_div(left, right, (uint64_t)(message_max - 1)) % message_max; + + if (is_big == 1) { + cks_in_use = cks; + sks_in_use = sks; + + int encrypt_left_ok = shortint_client_key_encrypt(cks_in_use, left, &ct_left); + assert(encrypt_left_ok == 0); + + int encrypt_right_ok = shortint_client_key_encrypt(cks_in_use, right, &ct_right); + assert(encrypt_right_ok == 0); + } else { + cks_in_use = cks_small; + sks_in_use = sks_small; + + int encrypt_left_ok = shortint_client_key_encrypt(cks_in_use, left, &ct_left); + assert(encrypt_left_ok == 0); + + int encrypt_right_ok = shortint_client_key_encrypt(cks_in_use, right, &ct_right); + assert(encrypt_right_ok == 0); + } + + int api_call_ok = + shortint_server_key_unchecked_div(sks_in_use, ct_left, ct_right, &ct_result); + assert(api_call_ok == 0); + + uint64_t decrypted_result = -1; + + int decrypt_ok = shortint_client_key_decrypt(cks_in_use, ct_result, &decrypted_result); + assert(decrypt_ok == 0); + + assert(decrypted_result == expected); + + shortint_destroy_ciphertext(ct_left); + shortint_destroy_ciphertext(ct_right); + shortint_destroy_ciphertext(ct_result); + } + } + } +} + +void test_shortint_div_assign(const ShortintClientKey *cks, const ShortintServerKey *sks, + const ShortintClientKey *cks_small, + const ShortintServerKey *sks_small, const uint32_t message_bits) { + + int message_max = 1 << message_bits; + + for (int is_big = 0; is_big < 2; ++is_big) { + for (int val_left = 0; val_left < message_max; ++val_left) { + for (int val_right = 0; val_right < message_max; ++val_right) { + ShortintCiphertext *ct_left_and_result = NULL; + ShortintCiphertext *ct_right = NULL; + const ShortintClientKey *cks_in_use = NULL; + const ShortintServerKey *sks_in_use = NULL; + + uint64_t left = (uint64_t)val_left; + uint64_t right = (uint64_t)val_right; + + uint64_t expected = homomorphic_div(left, right, (uint64_t)(message_max - 1)) % message_max; + + if (is_big == 1) { + cks_in_use = cks; + sks_in_use = sks; + + int encrypt_left_ok = shortint_client_key_encrypt(cks_in_use, left, &ct_left_and_result); + assert(encrypt_left_ok == 0); + + int encrypt_right_ok = shortint_client_key_encrypt(cks_in_use, right, &ct_right); + assert(encrypt_right_ok == 0); + } else { + cks_in_use = cks_small; + sks_in_use = sks_small; + + int encrypt_left_ok = shortint_client_key_encrypt(cks_in_use, left, &ct_left_and_result); + assert(encrypt_left_ok == 0); + + int encrypt_right_ok = shortint_client_key_encrypt(cks_in_use, right, &ct_right); + assert(encrypt_right_ok == 0); + } + + int api_call_ok = + shortint_server_key_unchecked_div_assign(sks_in_use, ct_left_and_result, ct_right); + assert(api_call_ok == 0); + + uint64_t decrypted_result = -1; + + int decrypt_ok = + shortint_client_key_decrypt(cks_in_use, ct_left_and_result, &decrypted_result); + assert(decrypt_ok == 0); + + assert(decrypted_result == expected); + + shortint_destroy_ciphertext(ct_left_and_result); + shortint_destroy_ciphertext(ct_right); + } + } + } +} + void test_shortint_binary_scalar_op( const ShortintClientKey *cks, const ShortintServerKey *sks, const ShortintClientKey *cks_small, - const ShortintServerKey *sks_small, const uint32_t message_bits, const uint32_t carry_bits, + const ShortintServerKey *sks_small, const uint32_t message_bits, uint64_t (*c_fun)(uint64_t, uint8_t), int (*api_fun)(const ShortintServerKey *, ShortintCiphertext *, uint8_t, ShortintCiphertext **), uint8_t forbidden_scalar_values[], size_t forbidden_scalar_values_len) { @@ -302,7 +427,7 @@ void test_shortint_binary_scalar_op( void test_shortint_binary_scalar_op_assign( const ShortintClientKey *cks, const ShortintServerKey *sks, const ShortintClientKey *cks_small, - const ShortintServerKey *sks_small, const uint32_t message_bits, const uint32_t carry_bits, + const ShortintServerKey *sks_small, const uint32_t message_bits, uint64_t (*c_fun)(uint64_t, uint8_t), int (*api_fun)(const ShortintServerKey *, ShortintCiphertext *, uint8_t), uint8_t forbidden_scalar_values[], size_t forbidden_scalar_values_len) { @@ -371,15 +496,6 @@ uint64_t sub(uint64_t left, uint64_t right) { return left - right; } uint64_t mul(uint64_t left, uint64_t right) { return left * right; } uint64_t neg(uint64_t in) { return -in; } -uint64_t homomorphic_div(uint64_t left, uint64_t right) { - if (right != 0) { - return left / right; - } else { - // Special value chosen in the shortint implementation in case of a division by 0 - return 0; - } -} - uint64_t bitand(uint64_t left, uint64_t right) { return left & right; } uint64_t bitxor(uint64_t left, uint64_t right) { return left ^ right; } uint64_t bitor (uint64_t left, uint64_t right) { return left | right; } @@ -473,258 +589,237 @@ void test_server_key(void) { assert(deser_sks_ok == 0); printf("add\n"); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, add, + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, add, (BinaryCallback)shortint_server_key_smart_add); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, add, + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, add, (BinaryCallback)shortint_server_key_unchecked_add); - test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, add, + test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, add, (BinaryAssignCallback)shortint_server_key_smart_add_assign); - test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, add, + test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, add, (BinaryAssignCallback)shortint_server_key_unchecked_add_assign); printf("sub\n"); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, sub, + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, sub, (BinaryCallback)shortint_server_key_smart_sub); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, sub, + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, sub, (BinaryCallback)shortint_server_key_unchecked_sub); - test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, sub, + test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, sub, (BinaryAssignCallback)shortint_server_key_smart_sub_assign); - test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, sub, + test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, sub, (BinaryAssignCallback)shortint_server_key_unchecked_sub_assign); printf("mul\n"); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, mul, + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, mul, (BinaryCallback)shortint_server_key_smart_mul); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, mul, + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, mul, (BinaryCallback)shortint_server_key_unchecked_mul); - test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, mul, + test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, mul, (BinaryAssignCallback)shortint_server_key_smart_mul_assign); - test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, mul, + test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, mul, (BinaryAssignCallback)shortint_server_key_unchecked_mul_assign); printf("left_shift\n"); test_shortint_binary_scalar_op( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, left_shift, + deser_cks, deser_sks, cks_small, sks_small, message_bits, left_shift, (BinaryScalarCallback)shortint_server_key_smart_scalar_left_shift, NULL, 0); test_shortint_binary_scalar_op( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, left_shift, + deser_cks, deser_sks, cks_small, sks_small, message_bits, left_shift, (BinaryScalarCallback)shortint_server_key_unchecked_scalar_left_shift, NULL, 0); test_shortint_binary_scalar_op_assign( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, left_shift, + deser_cks, deser_sks, cks_small, sks_small, message_bits, left_shift, shortint_server_key_smart_scalar_left_shift_assign, NULL, 0); test_shortint_binary_scalar_op_assign( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, left_shift, + deser_cks, deser_sks, cks_small, sks_small, message_bits, left_shift, shortint_server_key_unchecked_scalar_left_shift_assign, NULL, 0); printf("right_shift\n"); test_shortint_binary_scalar_op( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, right_shift, + deser_cks, deser_sks, cks_small, sks_small, message_bits, right_shift, (BinaryScalarCallback)shortint_server_key_smart_scalar_right_shift, NULL, 0); test_shortint_binary_scalar_op( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, right_shift, + deser_cks, deser_sks, cks_small, sks_small, message_bits, right_shift, (BinaryScalarCallback)shortint_server_key_unchecked_scalar_right_shift, NULL, 0); test_shortint_binary_scalar_op_assign( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, right_shift, + deser_cks, deser_sks, cks_small, sks_small, message_bits, right_shift, shortint_server_key_smart_scalar_right_shift_assign, NULL, 0); test_shortint_binary_scalar_op_assign( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, right_shift, + deser_cks, deser_sks, cks_small, sks_small, message_bits, right_shift, shortint_server_key_unchecked_scalar_right_shift_assign, NULL, 0); printf("scalar_add\n"); test_shortint_binary_scalar_op( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_add, + deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_add, (BinaryScalarCallback)shortint_server_key_smart_scalar_add, NULL, 0); test_shortint_binary_scalar_op( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_add, + deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_add, (BinaryScalarCallback)shortint_server_key_unchecked_scalar_add, NULL, 0); test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, scalar_add, - shortint_server_key_smart_scalar_add_assign, NULL, 0); + scalar_add, shortint_server_key_smart_scalar_add_assign, + NULL, 0); test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, scalar_add, - shortint_server_key_unchecked_scalar_add_assign, NULL, 0); + scalar_add, shortint_server_key_unchecked_scalar_add_assign, + NULL, 0); printf("scalar_sub\n"); test_shortint_binary_scalar_op( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_sub, + deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_sub, (BinaryScalarCallback)shortint_server_key_smart_scalar_sub, NULL, 0); test_shortint_binary_scalar_op( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_sub, + deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_sub, (BinaryScalarCallback)shortint_server_key_unchecked_scalar_sub, NULL, 0); test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, scalar_sub, - shortint_server_key_smart_scalar_sub_assign, NULL, 0); + scalar_sub, shortint_server_key_smart_scalar_sub_assign, + NULL, 0); test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, scalar_sub, - shortint_server_key_unchecked_scalar_sub_assign, NULL, 0); + scalar_sub, shortint_server_key_unchecked_scalar_sub_assign, + NULL, 0); printf("scalar_mul\n"); test_shortint_binary_scalar_op( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_mul, + deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_mul, (BinaryScalarCallback)shortint_server_key_smart_scalar_mul, NULL, 0); test_shortint_binary_scalar_op( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_mul, + deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_mul, (BinaryScalarCallback)shortint_server_key_unchecked_scalar_mul, NULL, 0); test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, scalar_mul, - shortint_server_key_smart_scalar_mul_assign, NULL, 0); + scalar_mul, shortint_server_key_smart_scalar_mul_assign, + NULL, 0); test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, scalar_mul, - shortint_server_key_unchecked_scalar_mul_assign, NULL, 0); + scalar_mul, shortint_server_key_unchecked_scalar_mul_assign, + NULL, 0); printf("bitand\n"); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, bitand, (BinaryCallback)shortint_server_key_smart_bitand); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, bitand, (BinaryCallback)shortint_server_key_unchecked_bitand); - test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, bitand, + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitand, + (BinaryCallback)shortint_server_key_smart_bitand); + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitand, + (BinaryCallback)shortint_server_key_unchecked_bitand); + test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitand, (BinaryAssignCallback)shortint_server_key_smart_bitand_assign); - test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, bitand, + test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitand, (BinaryAssignCallback)shortint_server_key_unchecked_bitand_assign); printf("bitxor\n"); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, - bitxor, (BinaryCallback)shortint_server_key_smart_bitxor); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, - bitxor, (BinaryCallback)shortint_server_key_unchecked_bitxor); - test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, bitxor, + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitxor, + (BinaryCallback)shortint_server_key_smart_bitxor); + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitxor, + (BinaryCallback)shortint_server_key_unchecked_bitxor); + test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitxor, (BinaryAssignCallback)shortint_server_key_smart_bitxor_assign); - test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, bitxor, + test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitxor, (BinaryAssignCallback)shortint_server_key_unchecked_bitxor_assign); printf("bitor\n"); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, - bitor, (BinaryCallback)shortint_server_key_smart_bitor); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, - bitor, (BinaryCallback)shortint_server_key_unchecked_bitor); - test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, bitor, + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitor, + (BinaryCallback)shortint_server_key_smart_bitor); + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitor, + (BinaryCallback)shortint_server_key_unchecked_bitor); + test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitor, (BinaryAssignCallback)shortint_server_key_smart_bitor_assign); - test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, bitor, + test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, bitor, (BinaryAssignCallback)shortint_server_key_unchecked_bitor_assign); printf("greater\n"); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, - greater, (BinaryCallback)shortint_server_key_smart_greater); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, - greater, (BinaryCallback)shortint_server_key_unchecked_greater); + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, greater, + (BinaryCallback)shortint_server_key_smart_greater); + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, greater, + (BinaryCallback)shortint_server_key_unchecked_greater); printf("greater_or_equal\n"); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, greater_or_equal, (BinaryCallback)shortint_server_key_smart_greater_or_equal); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, greater_or_equal, (BinaryCallback)shortint_server_key_unchecked_greater_or_equal); printf("less\n"); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, - less, (BinaryCallback)shortint_server_key_smart_less); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, - less, (BinaryCallback)shortint_server_key_unchecked_less); + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, less, + (BinaryCallback)shortint_server_key_smart_less); + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, less, + (BinaryCallback)shortint_server_key_unchecked_less); printf("less_or_equal\n"); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, - less_or_equal, (BinaryCallback)shortint_server_key_smart_less_or_equal); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, - less_or_equal, + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, less_or_equal, + (BinaryCallback)shortint_server_key_smart_less_or_equal); + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, less_or_equal, (BinaryCallback)shortint_server_key_unchecked_less_or_equal); printf("equal\n"); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, - equal, (BinaryCallback)shortint_server_key_smart_equal); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, - equal, (BinaryCallback)shortint_server_key_unchecked_equal); + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, equal, + (BinaryCallback)shortint_server_key_smart_equal); + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, equal, + (BinaryCallback)shortint_server_key_unchecked_equal); printf("not_equal\n"); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, - not_equal, (BinaryCallback)shortint_server_key_smart_not_equal); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, - not_equal, (BinaryCallback)shortint_server_key_unchecked_not_equal); + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, not_equal, + (BinaryCallback)shortint_server_key_smart_not_equal); + test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, not_equal, + (BinaryCallback)shortint_server_key_unchecked_not_equal); printf("scalar_greater\n"); test_shortint_binary_scalar_op( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_greater, + deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_greater, (BinaryScalarCallback)shortint_server_key_smart_scalar_greater, NULL, 0); printf("scalar_greater_or_equal\n"); test_shortint_binary_scalar_op( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_greater_or_equal, + deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_greater_or_equal, (BinaryScalarCallback)shortint_server_key_smart_scalar_greater_or_equal, NULL, 0); printf("scalar_less\n"); test_shortint_binary_scalar_op( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_less, + deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_less, (BinaryScalarCallback)shortint_server_key_smart_scalar_less, NULL, 0); printf("scalar_less_or_equal\n"); test_shortint_binary_scalar_op( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_less_or_equal, + deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_less_or_equal, (BinaryScalarCallback)shortint_server_key_smart_scalar_less_or_equal, NULL, 0); printf("scalar_equal\n"); test_shortint_binary_scalar_op( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_equal, + deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_equal, (BinaryScalarCallback)shortint_server_key_smart_scalar_equal, NULL, 0); printf("scalar_not_equal\n"); test_shortint_binary_scalar_op( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_not_equal, + deser_cks, deser_sks, cks_small, sks_small, message_bits, scalar_not_equal, (BinaryScalarCallback)shortint_server_key_smart_scalar_not_equal, NULL, 0); printf("neg\n"); - test_shortint_unary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, neg, + test_shortint_unary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, neg, (UnaryCallback)shortint_server_key_smart_neg); - test_shortint_unary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, neg, + test_shortint_unary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, neg, (UnaryCallback)shortint_server_key_unchecked_neg); - test_shortint_unary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, neg, + test_shortint_unary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, neg, (UnaryAssignCallback)shortint_server_key_smart_neg_assign); - test_shortint_unary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, neg, + test_shortint_unary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, neg, (UnaryAssignCallback)shortint_server_key_unchecked_neg_assign); printf("div\n"); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, - homomorphic_div, (BinaryCallback)shortint_server_key_smart_div); - test_shortint_binary_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, - homomorphic_div, (BinaryCallback)shortint_server_key_unchecked_div); - test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, homomorphic_div, - (BinaryAssignCallback)shortint_server_key_smart_div_assign); - test_shortint_binary_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, homomorphic_div, - (BinaryAssignCallback)shortint_server_key_unchecked_div_assign); + test_shortint_div(deser_cks, deser_sks, cks_small, sks_small, message_bits); + test_shortint_div(deser_cks, deser_sks, cks_small, sks_small, message_bits); + test_shortint_div_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits); + test_shortint_div_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits); printf("scalar_div\n"); uint8_t forbidden_scalar_div_values[1] = {0}; test_shortint_binary_scalar_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, scalar_div, + scalar_div, (BinaryScalarCallback)shortint_server_key_unchecked_scalar_div, forbidden_scalar_div_values, 1); - test_shortint_binary_scalar_op_assign( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_div, - shortint_server_key_unchecked_scalar_div_assign, forbidden_scalar_div_values, 1); + test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, + scalar_div, shortint_server_key_unchecked_scalar_div_assign, + forbidden_scalar_div_values, 1); printf("scalar_mod\n"); uint8_t forbidden_scalar_mod_values[1] = {0}; test_shortint_binary_scalar_op(deser_cks, deser_sks, cks_small, sks_small, message_bits, - carry_bits, scalar_mod, + scalar_mod, (BinaryScalarCallback)shortint_server_key_unchecked_scalar_mod, forbidden_scalar_mod_values, 1); - test_shortint_binary_scalar_op_assign( - deser_cks, deser_sks, cks_small, sks_small, message_bits, carry_bits, scalar_mod, - shortint_server_key_unchecked_scalar_mod_assign, forbidden_scalar_mod_values, 1); + test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, cks_small, sks_small, message_bits, + scalar_mod, shortint_server_key_unchecked_scalar_mod_assign, + forbidden_scalar_mod_values, 1); shortint_destroy_client_key(cks); shortint_destroy_client_key(cks_small); diff --git a/tfhe/src/shortint/engine/server_side/div_mod.rs b/tfhe/src/shortint/engine/server_side/div_mod.rs index 56bd14312..f55cdab42 100644 --- a/tfhe/src/shortint/engine/server_side/div_mod.rs +++ b/tfhe/src/shortint/engine/server_side/div_mod.rs @@ -2,10 +2,10 @@ use crate::shortint::ciphertext::Degree; use crate::shortint::engine::{EngineResult, ShortintEngine}; use crate::shortint::{Ciphertext, ServerKey}; -// Specific division function returning 0 in case of a division by 0 -pub(crate) fn safe_division(x: u64, y: u64) -> u64 { +// Specific division function returning value_on_div_by_zero in case of a division by 0 +pub(crate) fn safe_division(x: u64, y: u64, value_on_div_by_zero: u64) -> u64 { if y == 0 { - 0 + value_on_div_by_zero } else { x / y } @@ -29,11 +29,12 @@ impl ShortintEngine { ct_left: &mut Ciphertext, ct_right: &Ciphertext, ) -> EngineResult<()> { + let value_on_div_by_zero = (ct_left.message_modulus.0 - 1) as u64; self.unchecked_evaluate_bivariate_function_assign( server_key, ct_left, ct_right, - safe_division, + |x, y| safe_division(x, y, value_on_div_by_zero), )?; Ok(()) } @@ -93,7 +94,7 @@ impl ShortintEngine { ct: &mut Ciphertext, scalar: u8, ) -> EngineResult<()> { - assert_ne!(scalar, 0); + assert_ne!(scalar, 0, "attempt to divide by zero"); let lookup_table = self.generate_lookup_table(server_key, |x| x / (scalar as u64))?; self.apply_lookup_table_assign(server_key, ct, &lookup_table)?; ct.degree = Degree(ct.degree.0 / scalar as usize); diff --git a/tfhe/src/shortint/server_key/div_mod.rs b/tfhe/src/shortint/server_key/div_mod.rs index 5dc49745c..3086b415c 100644 --- a/tfhe/src/shortint/server_key/div_mod.rs +++ b/tfhe/src/shortint/server_key/div_mod.rs @@ -9,7 +9,8 @@ impl ServerKey { /// /// # Warning /// - /// /!\ A division by zero returns 0! + /// /!\ A division by zero returns the input ciphertext maximum message value! For 2 bits of + /// message it will therefore return 3. /// /// This function, like all "default" operations (i.e. not smart, checked or unchecked), will /// check that the input ciphertext carries are empty and clears them if it's not the case and @@ -69,7 +70,8 @@ impl ServerKey { /// /// # Warning /// - /// /!\ A division by zero returns 0! + /// /!\ A division by zero returns the input ciphertext maximum message value! For 2 bits of + /// message it will therefore return 3. /// /// This function, like all "default" operations (i.e. not smart, checked or unchecked), will /// check that the input ciphertext carries are empty and clears them if it's not the case and @@ -140,7 +142,8 @@ impl ServerKey { /// /// # Warning /// - /// /!\ A division by zero returns 0! + /// /!\ A division by zero returns the input ciphertext maximum message value! For 2 bits of + /// message it will therefore return 3. /// /// # Example /// @@ -192,7 +195,8 @@ impl ServerKey { /// /// # Warning /// - /// /!\ A division by zero returns 0! + /// /!\ A division by zero returns the input ciphertext maximum message value! For 2 bits of + /// message it will therefore return 3. /// /// # Example /// @@ -246,7 +250,8 @@ impl ServerKey { /// /// # Warning /// - /// /!\ A division by zero returns 0! + /// /!\ A division by zero returns the input ciphertext maximum message value! For 2 bits of + /// message it will therefore return 3. /// /// # Example /// @@ -298,7 +303,8 @@ impl ServerKey { /// /// # Warning /// - /// /!\ A division by zero returns 0! + /// /!\ A division by zero returns the input ciphertext maximum message value! For 2 bits of + /// message it will therefore return 3. /// /// # Example /// diff --git a/tfhe/src/shortint/server_key/tests/shortint.rs b/tfhe/src/shortint/server_key/tests/shortint.rs index 6aa785c73..490b8a9f2 100644 --- a/tfhe/src/shortint/server_key/tests/shortint.rs +++ b/tfhe/src/shortint/server_key/tests/shortint.rs @@ -2072,6 +2072,19 @@ where let modulus = cks.parameters.message_modulus().0 as u64; + // check div by 0 result + { + let numerator = 1u64; + let denominator = 0u64; + + let ct_num = cks.encrypt(numerator); + let ct_denom = cks.encrypt(denominator); + let ct_res = sks.unchecked_div(&ct_num, &ct_denom); + + let res = cks.decrypt(&ct_res); + assert_eq!(res, (ct_num.message_modulus.0 - 1) as u64) + } + for _ in 0..NB_TEST { let clear_0 = rng.gen::() % modulus; let clear_1 = (rng.gen::() % (modulus - 1)) + 1;