From 02e6d3c95543bd13e33725a4b63995e5a4852e11 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 30 Jan 2023 16:46:57 +0100 Subject: [PATCH] feat(c_api): expose create_trivial for shortint in C api --- tfhe/c_api_tests/test_shortint_keygen.c | 28 +++++++++++++++++++++ tfhe/src/c_api/shortint/server_key/mod.rs | 19 ++++++++++++++ tfhe/src/shortint/engine/server_side/mod.rs | 4 +-- tfhe/src/shortint/server_key/mod.rs | 4 +-- 4 files changed, 51 insertions(+), 4 deletions(-) diff --git a/tfhe/c_api_tests/test_shortint_keygen.c b/tfhe/c_api_tests/test_shortint_keygen.c index 45ee85bbf..74180db08 100644 --- a/tfhe/c_api_tests/test_shortint_keygen.c +++ b/tfhe/c_api_tests/test_shortint_keygen.c @@ -76,6 +76,33 @@ void test_predefined_keygen_w_serde(void) { destroy_buffer(&ct_ser_buffer); } +void test_server_key_trivial_encrypt(void) { + ShortintClientKey *cks = NULL; + ShortintServerKey *sks = NULL; + ShortintParameters *params = NULL; + ShortintCiphertext *ct = NULL; + + int get_params_ok = shortint_get_parameters(2, 2, ¶ms); + assert(get_params_ok == 0); + + int gen_keys_ok = shortint_gen_keys_with_parameters(params, &cks, &sks); + assert(gen_keys_ok == 0); + + int encrypt_ok = shortint_server_key_create_trivial(sks, 3, &ct); + assert(encrypt_ok == 0); + + uint64_t result = -1; + int decrypt_ok = shortint_client_key_decrypt(cks, ct, &result); + assert(decrypt_ok == 0); + + assert(result == 3); + + destroy_shortint_client_key(cks); + destroy_shortint_server_key(sks); + destroy_shortint_parameters(params); + destroy_shortint_ciphertext(ct); +} + void test_custom_keygen(void) { ShortintClientKey *cks = NULL; ShortintServerKey *sks = NULL; @@ -188,5 +215,6 @@ int main(void) { test_custom_keygen(); test_public_keygen(); test_compressed_public_keygen(); + test_server_key_trivial_encrypt(); return EXIT_SUCCESS; } diff --git a/tfhe/src/c_api/shortint/server_key/mod.rs b/tfhe/src/c_api/shortint/server_key/mod.rs index cb6e8a758..a48165278 100644 --- a/tfhe/src/c_api/shortint/server_key/mod.rs +++ b/tfhe/src/c_api/shortint/server_key/mod.rs @@ -46,6 +46,25 @@ pub unsafe extern "C" fn shortint_gen_server_key( }) } +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_create_trivial( + server_key: *const ShortintServerKey, + value_to_trivially_encrypt: u64, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + + let heap_allocated_ciphertext = Box::new(ShortintCiphertext( + server_key.0.create_trivial(value_to_trivially_encrypt), + )); + + *result = Box::into_raw(heap_allocated_ciphertext); + }) +} + #[no_mangle] pub unsafe extern "C" fn shortint_serialize_server_key( server_key: *const ShortintServerKey, diff --git a/tfhe/src/shortint/engine/server_side/mod.rs b/tfhe/src/shortint/engine/server_side/mod.rs index 5bc40a8ae..4cc10a399 100644 --- a/tfhe/src/shortint/engine/server_side/mod.rs +++ b/tfhe/src/shortint/engine/server_side/mod.rs @@ -467,7 +467,7 @@ impl ShortintEngine { pub(crate) fn create_trivial( &mut self, server_key: &ServerKey, - value: u8, + value: u64, ) -> EngineResult { let lwe_size = server_key .bootstrapping_key @@ -499,7 +499,7 @@ impl ShortintEngine { &mut self, server_key: &ServerKey, ct: &mut Ciphertext, - value: u8, + value: u64, ) -> EngineResult<()> { let modular_value = value as usize % server_key.message_modulus.0; diff --git a/tfhe/src/shortint/server_key/mod.rs b/tfhe/src/shortint/server_key/mod.rs index 52d8164a7..ab332752b 100644 --- a/tfhe/src/shortint/server_key/mod.rs +++ b/tfhe/src/shortint/server_key/mod.rs @@ -534,11 +534,11 @@ impl ServerKey { /// let ct_res = cks.decrypt(&ct1); /// assert_eq!(1, ct_res); /// ``` - pub fn create_trivial(&self, value: u8) -> Ciphertext { + pub fn create_trivial(&self, value: u64) -> Ciphertext { ShortintEngine::with_thread_local_mut(|engine| engine.create_trivial(self, value).unwrap()) } - pub fn create_trivial_assign(&self, ct: &mut Ciphertext, value: u8) { + pub fn create_trivial_assign(&self, ct: &mut Ciphertext, value: u64) { ShortintEngine::with_thread_local_mut(|engine| { engine.create_trivial_assign(self, ct, value).unwrap() })